@@ -119,11 +119,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
119
119
}
120
120
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
121
121
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
122
- << " QuantMode = "
123
- << (QuantMode == ck_tile::QuantType::AQuantGrouped
124
- ? " AQuantGrouped"
125
- : (QuantMode == ck_tile::QuantType::BQuantGrouped ? " BQuantGrouped"
126
- : " RowColQuant" ))
122
+ << " QuantMode = " << quant_type_to_string (QuantMode)
127
123
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? " true" : " false" ) << " : "
128
124
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
129
125
<< std::endl;
@@ -183,10 +179,11 @@ int run_gemm_example_with_layouts(int argc,
183
179
AQK = 0 ; // No A quantization
184
180
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
185
181
}
186
- else if constexpr (QuantMode == ck_tile::QuantType::RowColQuant)
182
+ else if constexpr (QuantMode == ck_tile::QuantType::RowColQuant ||
183
+ QuantMode == ck_tile::QuantType::TensorQuant)
187
184
{
188
- AQK = 1 ; // Row quantization: tensor shape [M, 1]
189
- BQK = N ; // Column quantization: tensor shape [1, N]
185
+ AQK = 1 ; // Row quantization: tensor shape [M, 1] or [1]
186
+ BQK = 1 ; // Column quantization: tensor shape [1, N] or [1 ]
190
187
}
191
188
else
192
189
{
@@ -227,6 +224,11 @@ int run_gemm_example_with_layouts(int argc,
227
224
stride_AQ = ck_tile::get_default_stride (M, 1 , stride_AQ, is_row_major (aq_layout));
228
225
stride_BQ = ck_tile::get_default_stride (1 , N, stride_BQ, is_row_major (bq_layout));
229
226
}
227
+ else if constexpr (QuantMode == ck_tile::QuantType::TensorQuant)
228
+ {
229
+ stride_AQ = 1 ; // Tensor quantization: tensor shape [1]
230
+ stride_BQ = 1 ; // Tensor quantization: tensor shape [1]
231
+ }
230
232
231
233
ck_tile::HostTensor<ADataType> a_m_k (
232
234
ck_tile::host_tensor_descriptor (M, K, stride_A, is_row_major (a_layout)));
@@ -237,28 +239,30 @@ int run_gemm_example_with_layouts(int argc,
237
239
238
240
// Create AQ tensor with appropriate shape
239
241
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr ;
240
- if constexpr (QuantMode == ck_tile::QuantType::AQuantGrouped)
242
+ if constexpr (QuantMode == ck_tile::QuantType::AQuantGrouped ||
243
+ QuantMode == ck_tile::QuantType::RowColQuant)
241
244
{
242
245
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
243
246
ck_tile::host_tensor_descriptor (M, AQK, stride_AQ, is_row_major (aq_layout)));
244
247
}
245
- else if (QuantMode == ck_tile::QuantType::RowColQuant )
248
+ else if constexpr (QuantMode == ck_tile::QuantType::TensorQuant )
246
249
{
247
250
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
248
- ck_tile::host_tensor_descriptor (M, AQK , stride_AQ, is_row_major (aq_layout)));
251
+ ck_tile::host_tensor_descriptor (1 , 1 , stride_AQ, is_row_major (aq_layout)));
249
252
}
250
253
251
- // Create BQ tensor only for RowColQuant mode
254
+ // Create BQ tensor with appropriate shape
252
255
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr ;
253
- if constexpr (QuantMode == ck_tile::QuantType::BQuantGrouped)
256
+ if constexpr (QuantMode == ck_tile::QuantType::BQuantGrouped ||
257
+ QuantMode == ck_tile::QuantType::RowColQuant)
254
258
{
255
259
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
256
260
ck_tile::host_tensor_descriptor (BQK, N, stride_BQ, is_row_major (bq_layout)));
257
261
}
258
- else if constexpr (QuantMode == ck_tile::QuantType::RowColQuant )
262
+ else if constexpr (QuantMode == ck_tile::QuantType::TensorQuant )
259
263
{
260
264
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
261
- ck_tile::host_tensor_descriptor (1 , N , stride_BQ, is_row_major (bq_layout)));
265
+ ck_tile::host_tensor_descriptor (1 , 1 , stride_BQ, is_row_major (bq_layout)));
262
266
}
263
267
264
268
std::random_device rd;
@@ -282,7 +286,7 @@ int run_gemm_example_with_layouts(int argc,
282
286
*bq_tensor_ptr);
283
287
ck_tile::FillUniformDistribution<ADataType>{-5 .0f , 5 .0f , fill_seed (gen)}(a_m_k);
284
288
}
285
- else
289
+ else if constexpr (QuantMode == ck_tile::QuantType::AQuantGrouped)
286
290
{
287
291
if constexpr (std::is_same_v<ADataType, ck_tile::pk_int4_t >)
288
292
{
@@ -296,12 +300,15 @@ int run_gemm_example_with_layouts(int argc,
296
300
ck_tile::FillUniformDistribution<AQDataType>{-2 .0f , 2 .0f , fill_seed (gen)}(
297
301
*aq_tensor_ptr);
298
302
ck_tile::FillUniformDistribution<BDataType>{-5 .0f , 5 .0f , fill_seed (gen)}(b_k_n);
299
-
300
- if constexpr (QuantMode == ck_tile::QuantType::RowColQuant)
301
- {
302
- ck_tile::FillUniformDistribution<BQDataType>{-2 .0f , 2 .0f , fill_seed (gen)}(
303
- *bq_tensor_ptr);
304
- }
303
+ }
304
+ else
305
+ {
306
+ ck_tile::FillUniformDistribution<ADataType>{-2 .0f , 2 .0f , fill_seed (gen)}(a_m_k);
307
+ ck_tile::FillUniformDistribution<BDataType>{-2 .0f , 2 .0f , fill_seed (gen)}(b_k_n);
308
+ ck_tile::FillUniformDistribution<AQDataType>{-2 .0f , 2 .0f , fill_seed (gen)}(
309
+ *aq_tensor_ptr);
310
+ ck_tile::FillUniformDistribution<BQDataType>{-2 .0f , 2 .0f , fill_seed (gen)}(
311
+ *bq_tensor_ptr);
305
312
}
306
313
}
307
314
else if (init_method == 1 )
@@ -343,22 +350,25 @@ int run_gemm_example_with_layouts(int argc,
343
350
344
351
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr ;
345
352
if constexpr (QuantMode == ck_tile::QuantType::AQuantGrouped ||
346
- QuantMode == ck_tile::QuantType::RowColQuant)
353
+ QuantMode == ck_tile::QuantType::RowColQuant ||
354
+ QuantMode == ck_tile::QuantType::TensorQuant)
347
355
{
348
356
aq_dev_buf_ptr =
349
357
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes ());
350
358
}
351
359
352
360
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr ;
353
361
if constexpr (QuantMode == ck_tile::QuantType::BQuantGrouped ||
354
- QuantMode == ck_tile::QuantType::RowColQuant)
362
+ QuantMode == ck_tile::QuantType::RowColQuant ||
363
+ QuantMode == ck_tile::QuantType::TensorQuant)
355
364
{
356
365
bq_dev_buf_ptr =
357
366
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes ());
358
367
}
359
368
360
369
if constexpr (QuantMode == ck_tile::QuantType::AQuantGrouped ||
361
- QuantMode == ck_tile::QuantType::RowColQuant)
370
+ QuantMode == ck_tile::QuantType::RowColQuant ||
371
+ QuantMode == ck_tile::QuantType::TensorQuant)
362
372
{
363
373
if constexpr (GemmConfig::PreshuffleQuant)
364
374
{
@@ -398,7 +408,8 @@ int run_gemm_example_with_layouts(int argc,
398
408
c_m_n_dev_result.SetZero ();
399
409
400
410
if constexpr (QuantMode == ck_tile::QuantType::BQuantGrouped ||
401
- QuantMode == ck_tile::QuantType::RowColQuant)
411
+ QuantMode == ck_tile::QuantType::RowColQuant ||
412
+ QuantMode == ck_tile::QuantType::TensorQuant)
402
413
{
403
414
bq_dev_buf_ptr->ToDevice (bq_tensor_ptr->data ());
404
415
}
@@ -412,15 +423,9 @@ int run_gemm_example_with_layouts(int argc,
412
423
CLayout,
413
424
QuantGroupSize,
414
425
QuantMode>(a_m_k_dev_buf,
415
- (QuantMode == ck_tile::QuantType::AQuantGrouped ||
416
- QuantMode == ck_tile::QuantType::RowColQuant)
417
- ? aq_dev_buf_ptr.get ()
418
- : nullptr ,
426
+ aq_dev_buf_ptr.get (),
419
427
b_k_n_dev_buf,
420
- (QuantMode == ck_tile::QuantType::BQuantGrouped ||
421
- QuantMode == ck_tile::QuantType::RowColQuant)
422
- ? bq_dev_buf_ptr.get ()
423
- : nullptr ,
428
+ bq_dev_buf_ptr.get (),
424
429
c_m_n_dev_buf,
425
430
M,
426
431
N,
@@ -467,7 +472,7 @@ int run_gemm_example_with_layouts(int argc,
467
472
QuantGroupSize,
468
473
false >(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
469
474
}
470
- else
475
+ else if constexpr (QuantMode == ck_tile::QuantType::RowColQuant)
471
476
{
472
477
ck_tile::reference_gemm_rowcol_quant<ADataType,
473
478
AQDataType,
@@ -477,6 +482,16 @@ int run_gemm_example_with_layouts(int argc,
477
482
CDataType>(
478
483
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
479
484
}
485
+ else if constexpr (QuantMode == ck_tile::QuantType::TensorQuant)
486
+ {
487
+ ck_tile::reference_gemm_tensor_quant<ADataType,
488
+ AQDataType,
489
+ BDataType,
490
+ BQDataType,
491
+ AccDataType,
492
+ CDataType>(
493
+ a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
494
+ }
480
495
481
496
const float max_accumulated_value =
482
497
*std::max_element (c_m_n_host_ref.mData .begin (), c_m_n_host_ref.mData .end ());
0 commit comments