@@ -50,14 +50,14 @@ template<> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInsta
 
 // clang-format on
 
-int main()
+int main(int argc, char* argv[])
 {
     bool time_kernel = true;
 
-    constexpr auto num_rows = 65536;
-    constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{};
-    // constexpr auto dims = ck::Sequence<256, 512>{};
-    constexpr auto index_length = 2048;
+    ck::index_t num_rows = 65536;
+    constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{};
+    ck::index_t index_length = 2048;
+    ck::index_t dim_mask = 0xffff;
     constexpr AccDataType epsilon = 1e-4;
 
     auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); };
@@ -73,121 +73,140 @@ int main()
         BetaDataType,
         AccDataType,
         OutType>;
-
+    if(argc == 1)
+    {
+        // Use default value
+    }
+    else if(argc == 4)
+    {
+        num_rows = atoi(argv[1]);
+        dim_mask = strtol(argv[2], nullptr, 0);
+        index_length = atoi(argv[3]);
+    }
+    else
+    {
+        std::cout << "Usage of " << argv[0] << std::endl;
+        std::cout << "Arg1-3: num_rows dim_mask index_length" << std::endl;
+    }
     ck::static_for<0, dims.Size(), 1>{}([&](auto I) {
-        std::srand(std::time(nullptr));
-        constexpr auto current_dim = dims.At(I);
-        Tensor<EmbType> emb_a(f_host_tensor_desc_2d(num_rows, current_dim));
-        Tensor<EmbType> emb_b(f_host_tensor_desc_2d(num_rows, current_dim));
-        Tensor<EmbType> emb_c(f_host_tensor_desc_2d(num_rows, current_dim));
-
-        Tensor<IndexType> index_a(f_host_tensor_desc_1d(index_length));
-        Tensor<IndexType> index_b(f_host_tensor_desc_1d(index_length));
-        Tensor<IndexType> index_c(f_host_tensor_desc_1d(index_length));
-
-        Tensor<GammaDataType> gamma(f_host_tensor_desc_1d(current_dim));
-        Tensor<BetaDataType> beta(f_host_tensor_desc_1d(current_dim));
-
-        Tensor<OutType> out(f_host_tensor_desc_2d(index_length, current_dim));
-
-        emb_a.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
-        emb_b.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
-        emb_c.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
-
-        index_a.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
-        index_b.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
-        index_c.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
-
-        gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
-        beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
-
-        DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize());
-        DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize());
-        DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize());
-
-        DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize());
-        DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize());
-        DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize());
-
-        DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
-        DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
-
-        DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize());
-
-        emb_a_dev.ToDevice(emb_a.mData.data());
-        emb_b_dev.ToDevice(emb_b.mData.data());
-        emb_c_dev.ToDevice(emb_c.mData.data());
-
-        index_a_dev.ToDevice(index_a.mData.data());
-        index_b_dev.ToDevice(index_b.mData.data());
-        index_c_dev.ToDevice(index_c.mData.data());
-
-        gamma_dev.ToDevice(gamma.mData.data());
-        beta_dev.ToDevice(beta.mData.data());
-
-        auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
-        auto argument_ptr = device_instance.MakeArgumentPointer(
-            out_dev.GetDeviceBuffer(),
-            {ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
-             ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
-             ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
-            {ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
-             ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
-             ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
-            gamma_dev.GetDeviceBuffer(),
-            beta_dev.GetDeviceBuffer(),
-            current_dim,
-            index_length,
-            epsilon,
-            EmbElementwiseOperation{});
-        std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
-                  << std::endl
-                  << std::flush;
-
-        bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get());
-
-        if(!is_supported)
+        if(dim_mask & (1 << I.value))
         {
-            std::cout << "Runtime parameters are not supported" << std::endl;
-            return;
+            std::srand(std::time(nullptr));
+            constexpr auto current_dim = dims.At(I);
+            Tensor<EmbType> emb_a(f_host_tensor_desc_2d(num_rows, current_dim));
+            Tensor<EmbType> emb_b(f_host_tensor_desc_2d(num_rows, current_dim));
+            Tensor<EmbType> emb_c(f_host_tensor_desc_2d(num_rows, current_dim));
+
+            Tensor<IndexType> index_a(f_host_tensor_desc_1d(index_length));
+            Tensor<IndexType> index_b(f_host_tensor_desc_1d(index_length));
+            Tensor<IndexType> index_c(f_host_tensor_desc_1d(index_length));
+
+            Tensor<GammaDataType> gamma(f_host_tensor_desc_1d(current_dim));
+            Tensor<BetaDataType> beta(f_host_tensor_desc_1d(current_dim));
+
+            Tensor<OutType> out(f_host_tensor_desc_2d(index_length, current_dim));
+
+            emb_a.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
+            emb_b.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
+            emb_c.GenerateTensorValue(GeneratorTensor_3<EmbType>{0.0, 1.0});
+
+            index_a.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
+            index_b.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
+            index_c.GenerateTensorValue(GeneratorTensor_2<IndexType>{0, num_rows});
+
+            gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0});
+            beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0});
+
+            DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize());
+            DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize());
+            DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize());
+
+            DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize());
+            DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize());
+            DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize());
+
+            DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
+            DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
+
+            DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize());
+
+            emb_a_dev.ToDevice(emb_a.mData.data());
+            emb_b_dev.ToDevice(emb_b.mData.data());
+            emb_c_dev.ToDevice(emb_c.mData.data());
+
+            index_a_dev.ToDevice(index_a.mData.data());
+            index_b_dev.ToDevice(index_b.mData.data());
+            index_c_dev.ToDevice(index_c.mData.data());
+
+            gamma_dev.ToDevice(gamma.mData.data());
+            beta_dev.ToDevice(beta.mData.data());
+
+            auto device_instance = typename emb_kernel<EmbType, current_dim>::kernel_type{};
+            auto argument_ptr = device_instance.MakeArgumentPointer(
+                out_dev.GetDeviceBuffer(),
+                {ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
+                 ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
+                 ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
+                {ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
+                 ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
+                 ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
+                gamma_dev.GetDeviceBuffer(),
+                beta_dev.GetDeviceBuffer(),
+                current_dim,
+                index_length,
+                epsilon,
+                EmbElementwiseOperation{});
+            std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString()
+                      << std::endl
+                      << std::flush;
+
+            bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get());
+
+            if(!is_supported)
+            {
+                std::cout << "Runtime parameters are not supported" << std::endl;
+                return;
+            }
+
+            auto invoker_ptr = device_instance.MakeInvokerPointer();
+            float time_ms =
+                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+            bool pass = true;
+            {
+                Tensor<OutType> out_from_dev(f_host_tensor_desc_2d(index_length, current_dim));
+                ReferenceInstance ref;
+                auto ref_argument = ref.MakeArgument(out,
+                                                     emb_a,
+                                                     emb_b,
+                                                     emb_c,
+                                                     index_a,
+                                                     index_b,
+                                                     index_c,
+                                                     gamma,
+                                                     beta,
+                                                     num_rows,
+                                                     current_dim,
+                                                     index_length,
+                                                     epsilon);
+                auto ref_invoker = ref.MakeInvoker();
+                ref_invoker.Run(ref_argument);
+
+                out_dev.FromDevice(out_from_dev.mData.data());
+                pass &=
+                    ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3);
+            }
+
+            double total_read = current_dim * index_length * 3 * sizeof(EmbType) +
+                                current_dim * sizeof(GammaDataType) +
+                                current_dim * sizeof(BetaDataType);
+            double total_write = current_dim * index_length * sizeof(OutType);
+            double gbps = (total_read + total_write) / time_ms / 1e6;
+
+            std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms
+                      << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl
+                      << std::flush;
         }
-
-        auto invoker_ptr = device_instance.MakeInvokerPointer();
-        float time_ms = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
-
-        bool pass = true;
-        {
-            Tensor<OutType> out_from_dev(f_host_tensor_desc_2d(index_length, current_dim));
-            ReferenceInstance ref;
-            auto ref_argument = ref.MakeArgument(out,
-                                                 emb_a,
-                                                 emb_b,
-                                                 emb_c,
-                                                 index_a,
-                                                 index_b,
-                                                 index_c,
-                                                 gamma,
-                                                 beta,
-                                                 num_rows,
-                                                 current_dim,
-                                                 index_length,
-                                                 epsilon);
-            auto ref_invoker = ref.MakeInvoker();
-            ref_invoker.Run(ref_argument);
-
-            out_dev.FromDevice(out_from_dev.mData.data());
-            pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3);
-        }
-
-        double total_read = current_dim * index_length * 3 * sizeof(EmbType) +
-                            current_dim * sizeof(GammaDataType) +
-                            current_dim * sizeof(BetaDataType);
-        double total_write = current_dim * index_length * sizeof(OutType);
-        double gbps = (total_read + total_write) / time_ms / 1e6;
-
-        std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms
-                  << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl
-                  << std::flush;
     });

     return 0;
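
For readers skimming the diff: the new `dim_mask` argument is a bit mask over the entries of `dims`, so bit `i` enables the benchmark for `dims.At(i)`, and the default `0xffff` keeps all eight dimensions enabled. Below is a minimal standalone sketch of just that selection logic (not part of the change), with a plain loop and array standing in for `ck::static_for` and `ck::Sequence`:

```cpp
#include <cstdio>
#include <cstdlib>

int main(int argc, char* argv[])
{
    // Same dims and default mask as the example; strtol with base 0 accepts "0x5" or "5".
    const int dims[] = {256, 512, 768, 1024, 1536, 2048, 4096, 8192};
    long dim_mask    = (argc == 2) ? strtol(argv[1], nullptr, 0) : 0xffff;

    for(int i = 0; i < 8; ++i)
    {
        // Bit i selects dims[i], mirroring the if(dim_mask & (1 << I.value)) check in the diff.
        if(dim_mask & (1L << i))
            printf("would benchmark dim %d\n", dims[i]);
    }
    return 0;
}
```

With the actual example binary, an invocation like `./example 65536 0x90 2048` (binary name illustrative) would therefore benchmark only the 1536 and 8192 columns.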