@@ -118,15 +118,43 @@ void calculate_reference_multi_dimensional(
         e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());

         ::EDataType result = static_cast<::EDataType>(sum);
-        std::vector<::DDataType> d_vals;
-        for(const auto& d_tensor : ds_full_dims_host)
+        if(ds_full_dims_host.size() == 0)
         {
-            d_vals.push_back(ck_tile::type_convert<float>(d_tensor(e_idx)));
+            ;
         }
-        if(d_vals.size() == 2)
+        else if(ds_full_dims_host.size() == 1)
         {
-            cde_elementwise(
-                result, ck_tile::type_convert<float>(sum), d_vals[0], d_vals[1]);
+            cde_elementwise(result,
+                            ck_tile::type_convert<float>(sum),
+                            ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
+        }
+        else if(ds_full_dims_host.size() == 2)
+        {
+            cde_elementwise(result,
+                            ck_tile::type_convert<float>(sum),
+                            ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
+                            ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
+        }
+        else if(ds_full_dims_host.size() == 3)
+        {
+            cde_elementwise(result,
+                            ck_tile::type_convert<float>(sum),
+                            ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
+                            ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
+                            ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
+        }
+        else if(ds_full_dims_host.size() == 4)
+        {
+            cde_elementwise(result,
+                            ck_tile::type_convert<float>(sum),
+                            ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
+                            ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
+                            ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
+                            ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported NumDTensor for reference calculation");
         }

         e_full_dims_host_ref(e_idx) = static_cast<::EDataType>(result);
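Note on the hunk above: the reference path now dispatches at runtime on `ds_full_dims_host.size()`, one branch per arity from 0 to 4, and throws for anything larger. Since `NumDTensor` is a compile-time constant elsewhere in this example, the same call could in principle be expanded once with an index sequence instead of a branch chain. A minimal sketch, assuming a callable `cde_elementwise` and using `static_cast` as a stand-in for `ck_tile::type_convert`:

```cpp
#include <cstddef>
#include <utility>

// Expands to op(result, sum, ds[0](e_idx), ..., ds[N-1](e_idx)) for a compile-time N.
// With an empty index sequence this degenerates to op(result, sum).
template <std::size_t... Is, typename Op, typename E, typename Ds, typename Idx>
void apply_cde(std::index_sequence<Is...>, Op op, E& result, float sum,
               const Ds& ds, const Idx& e_idx)
{
    op(result, sum, static_cast<float>(ds[Is](e_idx))...);
}

// Hypothetical call site inside the reference loop:
//   apply_cde(std::make_index_sequence<NumDTensor>{}, cde_elementwise,
//             result, ck_tile::type_convert<float>(sum), ds_full_dims_host, e_idx);
```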
@@ -165,18 +193,69 @@ void calculate_reference_flat_indexing(
            sum += static_cast<::AccDataType>(a_val) * static_cast<::AccDataType>(b_val);
        }

-       std::vector<::DDataType> d_vals;
-       for(const auto& d_tensor : ds_full_dims_host)
+       ::EDataType result = static_cast<::EDataType>(sum);
+       if(ds_full_dims_host.size() == 0)
        {
-           d_vals.push_back(ck_tile::type_convert<float>(
-               d_tensor.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
+           ;
        }
-       ::EDataType result = static_cast<::EDataType>(sum);
-       if(d_vals.size() == 2)
+       else if(ds_full_dims_host.size() == 1)
+       {
+           cde_elementwise(result,
+                           ck_tile::type_convert<float>(sum),
+                           ck_tile::type_convert<float>(
+                               ds_full_dims_host[0].mData[g_flat * M_total * N_total +
+                                                          m_flat * N_total + n_flat]));
+       }
+       else if(ds_full_dims_host.size() == 2)
+       {
+           cde_elementwise(
+               result,
+               ck_tile::type_convert<float>(sum),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[0]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[1]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
+       }
+       else if(ds_full_dims_host.size() == 3)
+       {
+           cde_elementwise(
+               result,
+               ck_tile::type_convert<float>(sum),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[0]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[1]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[2]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
+       }
+       else if(ds_full_dims_host.size() == 4)
        {
            cde_elementwise(
-               result, ck_tile::type_convert<float>(sum), d_vals[0], d_vals[1]);
+               result,
+               ck_tile::type_convert<float>(sum),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[0]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[1]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[2]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
+               ck_tile::type_convert<float>(
+                   ds_full_dims_host[3]
+                       .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
+       }
+       else
+       {
+           throw std::runtime_error("Unsupported NumDTensor for reference calculation");
        }
+
        e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
            static_cast<::EDataType>(result);
    }
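The flat-indexing reference above repeats the row-major offset `g_flat * M_total * N_total + m_flat * N_total + n_flat` for every D tensor and for E. A small self-contained check of that layout, with illustrative sizes only:

```cpp
#include <cstddef>

// (g, m, n) flattened row-major over [G_total, M_total, N_total];
// same expression as used repeatedly in the diff above.
constexpr std::size_t flat_offset(std::size_t g, std::size_t m, std::size_t n,
                                  std::size_t M_total, std::size_t N_total)
{
    return g * M_total * N_total + m * N_total + n;
}

// With M_total = 4 and N_total = 5, element (g = 2, m = 3, n = 1) lands at
// 2 * 20 + 3 * 5 + 1 = 56.
static_assert(flat_offset(2, 3, 1, 4, 5) == 56, "row-major flattening");
```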
@@ -368,25 +447,34 @@ int run_batched_contraction_example_with_layouts(
            ck_tile::HostTensorDescriptor(Ds_dims[d], Ds_strides[d])));
    }

-   ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}(ds_full_dims_host[0]);
-   ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}(ds_full_dims_host[1]);
-
-   ck_tile::DeviceMem d0_full_dims_dev_buf(ds_full_dims_host[0].get_element_space_size_in_bytes());
-   ck_tile::DeviceMem d1_full_dims_dev_buf(ds_full_dims_host[1].get_element_space_size_in_bytes());
-   d0_full_dims_dev_buf.ToDevice(ds_full_dims_host[0].data());
-   d1_full_dims_dev_buf.ToDevice(ds_full_dims_host[1].data());
+   for(int d = 0; d < NumDTensor; ++d)
+   {
+       ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}(
+           ds_full_dims_host[d]);
+   }

-   std::array<const void*, NumDTensor> ds_ptr_buf = {d0_full_dims_dev_buf.GetDeviceBuffer(),
-                                                     d1_full_dims_dev_buf.GetDeviceBuffer()};
+   std::vector<std::unique_ptr<ck_tile::DeviceMem>> ds_full_dims_dev_buf;
+   for(int d = 0; d < NumDTensor; ++d)
+   {
+       ds_full_dims_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
+           ds_full_dims_host[d].get_element_space_size_in_bytes()));
+       ds_full_dims_dev_buf[d]->ToDevice(ds_full_dims_host[d].data());
+   }
+   std::array<const void*, NumDTensor> ds_ptr_buf;
+   for(int d = 0; d < NumDTensor; ++d)
+   {
+       ds_ptr_buf[d] = ds_full_dims_dev_buf[d]->GetDeviceBuffer();
+   }

    e_full_dims_dev_buf.SetZero();
    e_full_dims_host.SetZero();

    std::cout << "\n=== Running GPU Kernel ===" << std::endl;

-   using DsDataType     = ck_tile::tuple_array<::DDataType, NumDTensor>;
-   using DsLayout       = ck_tile::tuple_array<DLayout, NumDTensor>;
-   using CDEElementWise = AddDs;
+   using DsDataType = ck_tile::tuple_array<::DDataType, NumDTensor>;
+   using DsLayout   = ck_tile::tuple_array<DLayout, NumDTensor>;
+   using CDEElementWise =
+       std::conditional_t<NumDTensor == 0, ck_tile::element_wise::PassThrough, AddDs>;

    float ave_time =
        invoke_batched_contraction_kernel<::ADataType,
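The `CDEElementWise` selection above picks a pass-through epilogue when there are no D tensors and the D-summing functor otherwise. A minimal sketch of that `std::conditional_t` pattern with simplified stand-in functors (not the real ck_tile definitions):

```cpp
#include <type_traits>

struct PassThrough // stand-in: copies the accumulator straight to the output
{
    template <typename E, typename C>
    void operator()(E& e, const C& c) const { e = static_cast<E>(c); }
};

struct AddDs // stand-in: folds all D values onto the accumulator
{
    template <typename E, typename C, typename... Ds>
    void operator()(E& e, const C& c, const Ds&... ds) const
    {
        e = static_cast<E>((c + ... + ds)); // c + d0 + d1 + ...
    }
};

template <int NumDTensor>
using CDEElementWise = std::conditional_t<NumDTensor == 0, PassThrough, AddDs>;

static_assert(std::is_same_v<CDEElementWise<0>, PassThrough>, "");
static_assert(std::is_same_v<CDEElementWise<2>, AddDs>, "");
```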
@@ -427,11 +515,13 @@ int run_batched_contraction_example_with_layouts(
        "D, M: " + std::to_string(M_dims.size()) + "D, N: " + std::to_string(N_dims.size()) +
        "D, K: " + std::to_string(K_dims.size()) + "D"};

-   std::size_t flop =
-       std::size_t(2) * G_total * M_total * N_total * K_total; // Number of operations
-   std::size_t num_byte = sizeof(::ADataType) * G_total * M_total * K_total + // A tensor size
-                          sizeof(::BDataType) * G_total * N_total * K_total + // B tensor size
-                          sizeof(::EDataType) * G_total * M_total * N_total;  // E tensor size
+   std::size_t flop = std::size_t(2) * G_total * M_total * N_total * K_total +
+                      NumDTensor * G_total * M_total * N_total; // Number of operations
+   std::size_t num_byte =
+       sizeof(::ADataType) * G_total * M_total * K_total +              // A tensor size
+       sizeof(::BDataType) * G_total * N_total * K_total +              // B tensor size
+       sizeof(::DDataType) * NumDTensor * G_total * M_total * N_total + // D tensors
+       sizeof(::EDataType) * G_total * M_total * N_total;               // E tensor size

    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time; // TFlops calculation
    float gb_per_sec = num_byte / 1.E6 / ave_time;                 // GB/s calculation
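The updated counts above fold the D tensors into both the flop and byte totals (one add per D tensor per output element). Since `ave_time` is in milliseconds, `flop / 1e9 / ms` yields TFLOP/s and `bytes / 1e6 / ms` yields GB/s. A quick sanity check of the unit handling, with made-up numbers:

```cpp
#include <cstdio>

int main()
{
    const double flop     = 8.6e9; // illustrative 2 * G * M * N * K workload
    const double num_byte = 2.7e8; // illustrative total tensor traffic in bytes
    const double ave_time = 1.5;   // kernel time in milliseconds

    // (FLOP / 1e9) per millisecond equals TFLOP per second;
    // (bytes / 1e6) per millisecond equals GB per second.
    std::printf("%.2f TFlops, %.1f GB/s\n",
                flop / 1.E9 / ave_time, num_byte / 1.E6 / ave_time);
    return 0;
}
```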
@@ -443,23 +533,6 @@ int run_batched_contraction_example_with_layouts(
    std::cout << "Performance: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
              << " GB/s" << std::endl;

-   // DETAILED: Tensor information
-   std::cout << "\nDetailed Tensor Info:" << std::endl;
-   std::cout << "A tensor: " << G_total << " x " << M_total << " x " << K_total << " = "
-             << G_total * M_total * K_total << " elements ("
-             << (sizeof(::ADataType) * G_total * M_total * K_total) / 1024 / 1024 << " MB)"
-             << std::endl;
-   std::cout << "B tensor: " << G_total << " x " << N_total << " x " << K_total << " = "
-             << G_total * N_total * K_total << " elements ("
-             << (sizeof(::BDataType) * G_total * N_total * K_total) / 1024 / 1024 << " MB)"
-             << std::endl;
-   std::cout << "E tensor: " << G_total << " x " << M_total << " x " << N_total << " = "
-             << G_total * M_total * N_total << " elements ("
-             << (sizeof(::EDataType) * G_total * M_total * N_total) / 1024 / 1024 << " MB)"
-             << std::endl;
-   std::cout << "Total memory: " << num_byte / 1024 / 1024 << " MB" << std::endl;
-   std::cout << "Total FLOPs: " << flop / 1000000 << " million" << std::endl;
-
    std::cout << "===============================================" << std::endl;

    e_full_dims_dev_buf.FromDevice(e_full_dims_host.data());