Skip to content

Commit 8728380

Browse files
committed
add command options [part 2]
1 parent 40caa18 commit 8728380

File tree

50 files changed

+1262
-332
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1262
-332
lines changed

profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
4545
int O,
4646
int G0,
4747
int G1,
48-
float alpha = -1.f)
48+
float alpha = -1.f,
49+
int instance_index = -1)
4950

5051
{
51-
5252
using PassThrough = tensor_operation::element_wise::PassThrough;
5353
using ScaleAdd = tensor_operation::element_wise::ScaleAdd;
5454
using AElementOp = PassThrough;
@@ -273,7 +273,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
273273
float best_ave_time = 0;
274274
float best_tflops = 0;
275275
float best_gb_per_sec = 0;
276-
276+
int num_kernel = 0;
277277
// profile device op instances
278278
for(auto& op_ptr : op_ptrs)
279279
{
@@ -310,6 +310,13 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
310310

311311
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
312312
{
313+
++num_kernel;
314+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
315+
{
316+
// skip test if instance_index is specified
317+
continue;
318+
}
319+
313320
std::string op_name = op_ptr->GetTypeString();
314321

315322
float ave_time =
@@ -388,6 +395,11 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
388395
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
389396
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
390397

398+
if(instance_index != -1)
399+
{
400+
std::cout << "batched_gemm_bias_softmax_gemm_permute_instance (" << instance_index << "/"
401+
<< num_kernel << "): Passed" << std::endl;
402+
}
391403
return pass;
392404
}
393405

profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,19 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
4040
int N,
4141
int K,
4242
int O,
43-
int BatchCount = 1,
44-
int StrideA = -1,
45-
int StrideB0 = -1,
46-
int StrideB1 = -1,
47-
int StrideC = -1,
48-
int BatchStrideA = -1,
49-
int BatchStrideB0 = -1,
50-
int BatchStrideB1 = -1,
51-
int BatchStrideC = -1,
52-
float alpha = -1.f)
43+
int BatchCount = 1,
44+
int StrideA = -1,
45+
int StrideB0 = -1,
46+
int StrideB1 = -1,
47+
int StrideC = -1,
48+
int BatchStrideA = -1,
49+
int BatchStrideB0 = -1,
50+
int BatchStrideB1 = -1,
51+
int BatchStrideC = -1,
52+
float alpha = -1.f,
53+
int instance_index = -1)
5354

5455
{
55-
5656
using Row = tensor_layout::gemm::RowMajor;
5757
using Col = tensor_layout::gemm::ColumnMajor;
5858
using PassThrough = tensor_operation::element_wise::PassThrough;
@@ -251,7 +251,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
251251
float best_ave_time = 0;
252252
float best_tflops = 0;
253253
float best_gb_per_sec = 0;
254-
254+
int num_kernel = 0;
255255
// profile device op instances
256256
for(auto& op_ptr : op_ptrs)
257257
{
@@ -283,6 +283,13 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
283283

284284
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
285285
{
286+
++num_kernel;
287+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
288+
{
289+
// skip test if instance_index is specified
290+
continue;
291+
}
292+
286293
std::string op_name = op_ptr->GetTypeString();
287294

288295
float ave_time =
@@ -339,7 +346,11 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
339346

340347
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
341348
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
342-
349+
if(instance_index != -1)
350+
{
351+
std::cout << "batched_gemm_softmax_gemm_instance (" << instance_index << "/" << num_kernel
352+
<< "): Passed" << std::endl;
353+
}
343354
return pass;
344355
}
345356

profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
4545
int O,
4646
int G0,
4747
int G1,
48-
float alpha = -1.f)
48+
float alpha = -1.f,
49+
int instance_index = -1)
4950

5051
{
51-
5252
using PassThrough = tensor_operation::element_wise::PassThrough;
5353
using Scale = tensor_operation::element_wise::Scale;
5454
using AElementOp = PassThrough;
@@ -251,6 +251,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
251251
float best_ave_time = 0;
252252
float best_tflops = 0;
253253
float best_gb_per_sec = 0;
254+
int num_kernel = 0;
254255

255256
// profile device op instances
256257
for(auto& op_ptr : op_ptrs)
@@ -284,6 +285,13 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
284285

285286
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
286287
{
288+
++num_kernel;
289+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
290+
{
291+
// skip test if instance_index is specified
292+
continue;
293+
}
294+
287295
std::string op_name = op_ptr->GetTypeString();
288296

289297
float ave_time =
@@ -359,7 +367,11 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
359367

360368
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
361369
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
362-
370+
if(instance_index != -1)
371+
{
372+
std::cout << "batched_gemm_softmax_gemm_permute_instance (" << instance_index << "/"
373+
<< num_kernel << "): Passed" << std::endl;
374+
}
363375
return pass;
364376
}
365377

profiler/include/profiler/profile_contraction_impl.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ int profile_contraction_impl(ck::index_t do_verification,
5454
const std::vector<ck::index_t>& StridesA, // [M0, M1, K0, K1]
5555
const std::vector<ck::index_t>& StridesB, // [N0, N1, K0, K1]
5656
const std::vector<ck::index_t>& StridesE, // [M0, M1, N0, N1]
57-
const std::vector<ck::index_t>& StridesD) // [M0, M1, N0, N1]
57+
const std::vector<ck::index_t>& StridesD, // [M0, M1, N0, N1]
58+
int instance_index = -1)
5859
{
5960
bool pass = true;
6061

@@ -187,7 +188,7 @@ int profile_contraction_impl(ck::index_t do_verification,
187188
float best_avg_time = 0;
188189
float best_tflops = 0;
189190
float best_gb_per_sec = 0;
190-
191+
int num_kernel = 0;
191192
// profile device op instances
192193
for(auto& op_ptr : op_ptrs)
193194
{
@@ -246,6 +247,12 @@ int profile_contraction_impl(ck::index_t do_verification,
246247

247248
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
248249
{
250+
++num_kernel;
251+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
252+
{
253+
// skip test if instance_index is specified
254+
continue;
255+
}
249256
// re-init C to zero before profiling next kernel
250257
e_device_buf.SetZero();
251258

@@ -366,6 +373,11 @@ int profile_contraction_impl(ck::index_t do_verification,
366373
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
367374
<< best_op_name << std::endl;
368375

376+
if(instance_index != -1)
377+
{
378+
std::cout << "contraction_instance (" << instance_index << "/" << num_kernel << "): Passed"
379+
<< std::endl;
380+
}
369381
return pass;
370382
}
371383

profiler/include/profiler/profile_conv_bwd_data_impl.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ bool profile_conv_bwd_data_impl(int do_verification,
5858
int init_method,
5959
bool do_log,
6060
bool time_kernel,
61-
const ck::utils::conv::ConvParam& conv_param)
61+
const ck::utils::conv::ConvParam& conv_param,
62+
int instance_index = -1)
6263
{
6364
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
6465
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -174,7 +175,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
174175
float best_avg_time = 0;
175176
float best_tflops = 0;
176177
float best_gb_per_sec = 0;
177-
178+
int num_kernel = 0;
178179
// profile device Conv instances
179180
bool pass = true;
180181

@@ -200,6 +201,12 @@ bool profile_conv_bwd_data_impl(int do_verification,
200201

201202
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
202203
{
204+
++num_kernel;
205+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
206+
{
207+
// skip test if instance_index is specified
208+
continue;
209+
}
203210
// for conv bwd data, some input tensor element are zero, but not written by kernel,
204211
// need to set zero
205212
in_device_buf.SetZero();
@@ -263,7 +270,11 @@ bool profile_conv_bwd_data_impl(int do_verification,
263270
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
264271
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
265272
<< "\nGB/s: " << best_gb_per_sec << std::endl;
266-
273+
if(instance_index != -1)
274+
{
275+
std::cout << "conv_bwd_data_instance (" << instance_index << "/" << num_kernel
276+
<< "): Passed" << std::endl;
277+
}
267278
return pass;
268279
}
269280

profiler/include/profiler/profile_conv_fwd_impl.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ bool profile_conv_fwd_impl(int do_verification,
3636
int init_method,
3737
bool do_log,
3838
bool time_kernel,
39-
const ck::utils::conv::ConvParam& conv_param)
39+
const ck::utils::conv::ConvParam& conv_param,
40+
int instance_index = -1)
4041
{
4142
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
4243
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -156,7 +157,7 @@ bool profile_conv_fwd_impl(int do_verification,
156157
float best_avg_time = 0;
157158
float best_tflops = 0;
158159
float best_gb_per_sec = 0;
159-
160+
int num_kernel = 0;
160161
// profile device op instances
161162
bool pass = true;
162163

@@ -182,6 +183,12 @@ bool profile_conv_fwd_impl(int do_verification,
182183

183184
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
184185
{
186+
++num_kernel;
187+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
188+
{
189+
// skip test if instance_index is specified
190+
continue;
191+
}
185192
// re-init output to zero before profiling next kernel
186193
out_device_buf.SetZero();
187194

@@ -236,7 +243,11 @@ bool profile_conv_fwd_impl(int do_verification,
236243
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
237244
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
238245
<< "\nGB/s: " << best_gb_per_sec << std::endl;
239-
246+
if(instance_index != -1)
247+
{
248+
std::cout << "conv_fwd_instance (" << instance_index << "/" << num_kernel << "): Passed"
249+
<< std::endl;
250+
}
240251
return pass;
241252
}
242253

profiler/include/profiler/profile_gemm_reduce_impl.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ bool profile_gemm_reduce_impl(int do_verification,
7171
int K,
7272
int StrideA,
7373
int StrideB,
74-
int StrideC)
74+
int StrideC,
75+
int instance_index = -1)
7576
{
7677
bool pass = true;
7778

@@ -253,7 +254,7 @@ bool profile_gemm_reduce_impl(int do_verification,
253254
float best_ave_time = 0;
254255
float best_tflops = 0;
255256
float best_gb_per_sec = 0;
256-
257+
int num_kernel = 0;
257258
// profile device GEMM instances
258259
for(auto& gemm_ptr : gemm_ptrs)
259260
{
@@ -279,6 +280,12 @@ bool profile_gemm_reduce_impl(int do_verification,
279280

280281
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
281282
{
283+
++num_kernel;
284+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
285+
{
286+
// skip test if instance_index is specified
287+
continue;
288+
}
282289
// init DO, D1 to 0
283290
reduce0_device_buf.SetZero();
284291
reduce1_device_buf.SetZero();
@@ -349,7 +356,11 @@ bool profile_gemm_reduce_impl(int do_verification,
349356

350357
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
351358
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
352-
359+
if(instance_index != -1)
360+
{
361+
std::cout << "gemm_reduce_instance (" << instance_index << "/" << num_kernel << "): Passed"
362+
<< std::endl;
363+
}
353364
return pass;
354365
}
355366

profiler/include/profiler/profile_gemm_splitk_impl.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ bool profile_gemm_splitk_impl(int do_verification,
4545
int StrideC,
4646
int KBatch,
4747
int n_warmup,
48-
int n_iter)
48+
int n_iter,
49+
int instance_index = -1)
4950
{
5051
bool pass = true;
5152

@@ -145,6 +146,7 @@ bool profile_gemm_splitk_impl(int do_verification,
145146
float best_tflops = 0;
146147
float best_gb_per_sec = 0;
147148
float best_kbatch = 0;
149+
int num_kernel = 0;
148150

149151
// profile device GEMM instances
150152
for(auto& op_ptr : op_ptrs)
@@ -179,7 +181,12 @@ bool profile_gemm_splitk_impl(int do_verification,
179181

180182
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
181183
{
182-
184+
++num_kernel;
185+
if((instance_index != -1) && (instance_index + 1 != num_kernel))
186+
{
187+
// skip test if instance_index is specified
188+
continue;
189+
}
183190
// re-init C to zero before profiling next kernel
184191
c_device_buf.SetZero();
185192

@@ -298,7 +305,11 @@ bool profile_gemm_splitk_impl(int do_verification,
298305
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
299306
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
300307
<< " GB/s, " << best_op_name << std::endl;
301-
308+
if(instance_index != -1)
309+
{
310+
std::cout << "gemm_splitk_instance (" << instance_index << "/" << num_kernel << "): Passed"
311+
<< std::endl;
312+
}
302313
return pass;
303314
}
304315

0 commit comments

Comments
 (0)