Skip to content

Commit 4363a82

Browse files
authored
[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)
* rename gemm_group_quant to gemm_quant * Add TensorWise quant mode * Cshuffle epilogue tests with tensor scaling * Add tensor quant to example * Don't use readfirstlane for reading scales - doesn't work for some reason * Add to changelog * revert include - from a merge problem? * revert common.hpp include * revert host.hpp include * remove unused utility function * rename quant pipeline problem * refactor quant tests * remove aquant utils * use TEST_F * fix all tests by changing gemm config * Use typed tests * fix copyright
1 parent b765fe7 commit 4363a82

39 files changed

+1554
-1055
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
3131
* Added benchmarking support for tile engine GEMM Multi D.
3232
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
3333
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
34+
* Added tensor-wise quantization for CK_TILE GEMM
3435

3536
### Optimized
3637

example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#include "ck_tile/core.hpp"
1414
#include "ck_tile/ops/epilogue.hpp"
1515
#include "ck_tile/ops/gemm.hpp"
16-
#include "ck_tile/ops/gemm_group_quant.hpp"
16+
#include "ck_tile/ops/gemm_quant.hpp"
1717
#include "ck_tile/host.hpp"
1818
#include "quant_grouped_gemm.hpp"
1919

@@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
6565
constexpr auto memory_operation = memory_operation_.value;
6666
constexpr bool transpose_c = false;
6767

68-
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
69-
BDataType,
70-
AccDataType,
71-
AccDataType,
72-
GemmShape,
73-
GemmUniversalTraits,
74-
transpose_c,
75-
BDataType,
76-
scheduler>;
68+
using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<ADataType,
69+
BDataType,
70+
AccDataType,
71+
AccDataType,
72+
GemmShape,
73+
GemmUniversalTraits,
74+
transpose_c,
75+
BDataType,
76+
scheduler>;
7777

7878
using GemmPipeline = typename PipelineTypeTraits<
7979
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;

example/ck_tile/38_block_scale_gemm/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
55
- AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline
66
- BQuant kernel with blocks of B matrix sharing scales: custom GEMM pipeline
77
- Row and Column-wise scaled: scaling implemented in Epilogue
8+
- Tensor-wise scaled: scaling implemented in Epilogue
89

910
## build
1011
```
@@ -14,7 +15,6 @@ mkdir build && cd build
1415
../script/cmake-ck-dev.sh ../ <arch>
1516
# Compile the quant kernels
1617
make tile_example_gemm_quant_basic -j
17-
make tile_example_gemm_bquant_basic -j
1818
```
1919
This will result in an executable `build/bin/tile_example_gemm_quant_basic`
2020

@@ -37,7 +37,7 @@ args:
3737
-warmup number of iterations before benchmark the kernel (default:10)
3838
-repeat number of iterations to benchmark the kernel (default:100)
3939
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
40-
-quant_mode Which quant method to use (aquant, rowcol)
40+
-quant_mode Which quant method to use (aquant, bquant, tensor, rowcol)
4141
```
4242

4343
User need to select correct mapping of config for each quant mode:

example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
6666
constexpr auto tail_number_v = tail_number_.value;
6767
constexpr bool transpose_c = false;
6868

69+
// row-col and tensor quants use the regular pipeline, A/B quants use their own
6970
using PipelineProblem = std::conditional_t<
70-
QuantMode == ck_tile::QuantType::RowColQuant,
71-
ck_tile::GemmRowColQuantPipelineProblem<typename TypeConfig::ADataType,
72-
typename TypeConfig::BDataType,
73-
typename TypeConfig::AccDataType,
74-
typename TypeConfig::AccDataType,
75-
GemmShape,
76-
GemmTraits,
77-
transpose_c,
78-
ComputeDataType,
79-
GemmConfig::Scheduler,
80-
has_hot_loop_v,
81-
tail_number_v>,
71+
QuantMode == ck_tile::QuantType::RowColQuant ||
72+
QuantMode == ck_tile::QuantType::TensorQuant,
73+
ck_tile::GemmRowColTensorQuantPipelineProblem<typename TypeConfig::ADataType,
74+
typename TypeConfig::BDataType,
75+
typename TypeConfig::AccDataType,
76+
typename TypeConfig::AccDataType,
77+
GemmShape,
78+
GemmTraits,
79+
transpose_c,
80+
ComputeDataType,
81+
GemmConfig::Scheduler,
82+
has_hot_loop_v,
83+
tail_number_v>,
8284
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
8385
ck_tile::GemmAQuantPipelineProblem<typename TypeConfig::ADataType,
8486
typename TypeConfig::QDataType,
@@ -105,7 +107,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
105107
tail_number_v>>>;
106108

107109
using GemmPipeline = std::conditional_t<
108-
QuantMode == ck_tile::QuantType::RowColQuant,
110+
QuantMode == ck_tile::QuantType::RowColQuant ||
111+
QuantMode == ck_tile::QuantType::TensorQuant,
109112
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
110113
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
111114
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
@@ -241,10 +244,18 @@ int run_gemm_example(int argc, char* argv[])
241244
ck_tile::QuantType::RowColQuant>(
242245
a_layout, b_layout, argc, argv);
243246
}
247+
else if(quant_mode == "tensor")
248+
{
249+
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
250+
TypeConfig,
251+
128,
252+
ck_tile::QuantType::TensorQuant>(
253+
a_layout, b_layout, argc, argv);
254+
}
244255
else
245256
{
246257
throw std::runtime_error(
247-
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
258+
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
248259
}
249260
}
250261
else if(data_type == "bf8")
@@ -276,10 +287,18 @@ int run_gemm_example(int argc, char* argv[])
276287
ck_tile::QuantType::RowColQuant>(
277288
a_layout, b_layout, argc, argv);
278289
}
290+
else if(quant_mode == "tensor")
291+
{
292+
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
293+
TypeConfig,
294+
128,
295+
ck_tile::QuantType::TensorQuant>(
296+
a_layout, b_layout, argc, argv);
297+
}
279298
else
280299
{
281300
throw std::runtime_error(
282-
"Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'");
301+
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
283302
}
284303
}
285304
else if(data_type == "i4fp8")

example/ck_tile/38_block_scale_gemm/gemm_utils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "ck_tile/host/kernel_launch.hpp"
1010
#include "ck_tile/ops/epilogue.hpp"
1111
#include "ck_tile/ops/gemm.hpp"
12-
#include "ck_tile/ops/gemm_group_quant.hpp"
12+
#include "ck_tile/ops/gemm_quant.hpp"
1313

1414
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
1515
constexpr ck_tile::index_t get_k_warp_tile()
@@ -241,7 +241,7 @@ auto create_args(int argc, char* argv[])
241241
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
242242
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
243243
.insert("rotating_count", "1", "rotating count, defaults to 1")
244-
.insert("quant_mode", "aquant", "Choose aquant (default), bquant or rowcol");
244+
.insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol");
245245

246246
bool result = arg_parser.parse(argc, argv);
247247
return std::make_tuple(result, arg_parser);

example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
119119
}
120120
std::cout << " Acc_Type = " << DataTypeTraits<typename TypeConfig::AccDataType>::name
121121
<< " C_Type = " << DataTypeTraits<typename TypeConfig::CDataType>::name
122-
<< " QuantMode = "
123-
<< (QuantMode == ck_tile::QuantType::AQuantGrouped
124-
? "AQuantGrouped"
125-
: (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped"
126-
: "RowColQuant"))
122+
<< " QuantMode = " << quant_type_to_string(QuantMode)
127123
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
128124
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
129125
<< std::endl;
@@ -183,10 +179,11 @@ int run_gemm_example_with_layouts(int argc,
183179
AQK = 0; // No A quantization
184180
BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize
185181
}
186-
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
182+
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant ||
183+
QuantMode == ck_tile::QuantType::TensorQuant)
187184
{
188-
AQK = 1; // Row quantization: tensor shape [M, 1]
189-
BQK = N; // Column quantization: tensor shape [1, N]
185+
AQK = 1; // Row quantization: tensor shape [M, 1] or [1]
186+
BQK = 1; // Column quantization: tensor shape [1, N] or [1]
190187
}
191188
else
192189
{
@@ -227,6 +224,11 @@ int run_gemm_example_with_layouts(int argc,
227224
stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout));
228225
stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout));
229226
}
227+
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
228+
{
229+
stride_AQ = 1; // Tensor quantization: tensor shape [1]
230+
stride_BQ = 1; // Tensor quantization: tensor shape [1]
231+
}
230232

231233
ck_tile::HostTensor<ADataType> a_m_k(
232234
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
@@ -237,28 +239,30 @@ int run_gemm_example_with_layouts(int argc,
237239

238240
// Create AQ tensor with appropriate shape
239241
std::unique_ptr<ck_tile::HostTensor<AQDataType>> aq_tensor_ptr = nullptr;
240-
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
242+
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
243+
QuantMode == ck_tile::QuantType::RowColQuant)
241244
{
242245
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
243246
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
244247
}
245-
else if(QuantMode == ck_tile::QuantType::RowColQuant)
248+
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
246249
{
247250
aq_tensor_ptr = std::make_unique<ck_tile::HostTensor<AQDataType>>(
248-
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
251+
ck_tile::host_tensor_descriptor(1, 1, stride_AQ, is_row_major(aq_layout)));
249252
}
250253

251-
// Create BQ tensor only for RowColQuant mode
254+
// Create BQ tensor with appropriate shape
252255
std::unique_ptr<ck_tile::HostTensor<BQDataType>> bq_tensor_ptr = nullptr;
253-
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
256+
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
257+
QuantMode == ck_tile::QuantType::RowColQuant)
254258
{
255259
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
256260
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
257261
}
258-
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
262+
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
259263
{
260264
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
261-
ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout)));
265+
ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout)));
262266
}
263267

264268
std::random_device rd;
@@ -282,7 +286,7 @@ int run_gemm_example_with_layouts(int argc,
282286
*bq_tensor_ptr);
283287
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
284288
}
285-
else
289+
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
286290
{
287291
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
288292
{
@@ -296,12 +300,15 @@ int run_gemm_example_with_layouts(int argc,
296300
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
297301
*aq_tensor_ptr);
298302
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
299-
300-
if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
301-
{
302-
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
303-
*bq_tensor_ptr);
304-
}
303+
}
304+
else
305+
{
306+
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 2.0f, fill_seed(gen)}(a_m_k);
307+
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 2.0f, fill_seed(gen)}(b_k_n);
308+
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
309+
*aq_tensor_ptr);
310+
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
311+
*bq_tensor_ptr);
305312
}
306313
}
307314
else if(init_method == 1)
@@ -343,22 +350,25 @@ int run_gemm_example_with_layouts(int argc,
343350

344351
std::unique_ptr<ck_tile::DeviceMem> aq_dev_buf_ptr = nullptr;
345352
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
346-
QuantMode == ck_tile::QuantType::RowColQuant)
353+
QuantMode == ck_tile::QuantType::RowColQuant ||
354+
QuantMode == ck_tile::QuantType::TensorQuant)
347355
{
348356
aq_dev_buf_ptr =
349357
std::make_unique<ck_tile::DeviceMem>(aq_tensor_ptr->get_element_space_size_in_bytes());
350358
}
351359

352360
std::unique_ptr<ck_tile::DeviceMem> bq_dev_buf_ptr = nullptr;
353361
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
354-
QuantMode == ck_tile::QuantType::RowColQuant)
362+
QuantMode == ck_tile::QuantType::RowColQuant ||
363+
QuantMode == ck_tile::QuantType::TensorQuant)
355364
{
356365
bq_dev_buf_ptr =
357366
std::make_unique<ck_tile::DeviceMem>(bq_tensor_ptr->get_element_space_size_in_bytes());
358367
}
359368

360369
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped ||
361-
QuantMode == ck_tile::QuantType::RowColQuant)
370+
QuantMode == ck_tile::QuantType::RowColQuant ||
371+
QuantMode == ck_tile::QuantType::TensorQuant)
362372
{
363373
if constexpr(GemmConfig::PreshuffleQuant)
364374
{
@@ -398,7 +408,8 @@ int run_gemm_example_with_layouts(int argc,
398408
c_m_n_dev_result.SetZero();
399409

400410
if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped ||
401-
QuantMode == ck_tile::QuantType::RowColQuant)
411+
QuantMode == ck_tile::QuantType::RowColQuant ||
412+
QuantMode == ck_tile::QuantType::TensorQuant)
402413
{
403414
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data());
404415
}
@@ -412,15 +423,9 @@ int run_gemm_example_with_layouts(int argc,
412423
CLayout,
413424
QuantGroupSize,
414425
QuantMode>(a_m_k_dev_buf,
415-
(QuantMode == ck_tile::QuantType::AQuantGrouped ||
416-
QuantMode == ck_tile::QuantType::RowColQuant)
417-
? aq_dev_buf_ptr.get()
418-
: nullptr,
426+
aq_dev_buf_ptr.get(),
419427
b_k_n_dev_buf,
420-
(QuantMode == ck_tile::QuantType::BQuantGrouped ||
421-
QuantMode == ck_tile::QuantType::RowColQuant)
422-
? bq_dev_buf_ptr.get()
423-
: nullptr,
428+
bq_dev_buf_ptr.get(),
424429
c_m_n_dev_buf,
425430
M,
426431
N,
@@ -467,7 +472,7 @@ int run_gemm_example_with_layouts(int argc,
467472
QuantGroupSize,
468473
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
469474
}
470-
else
475+
else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant)
471476
{
472477
ck_tile::reference_gemm_rowcol_quant<ADataType,
473478
AQDataType,
@@ -477,6 +482,16 @@ int run_gemm_example_with_layouts(int argc,
477482
CDataType>(
478483
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
479484
}
485+
else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant)
486+
{
487+
ck_tile::reference_gemm_tensor_quant<ADataType,
488+
AQDataType,
489+
BDataType,
490+
BQDataType,
491+
AccDataType,
492+
CDataType>(
493+
a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref);
494+
}
480495

481496
const float max_accumulated_value =
482497
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());

include/ck_tile/core/tensor/load_tile.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#pragma once
55

@@ -158,4 +158,7 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<Windo
158158
{
159159
}
160160

161+
template <typename Tile>
162+
concept IsLoadableTile = requires { load_tile(std::declval<Tile>()); };
163+
161164
} // namespace ck_tile

0 commit comments

Comments
 (0)