Skip to content

Commit 4f89943

Browse files
committed
Generalize example code for variable NumD tensors and apply cleanup based on review feedback
1 parent 5d83464 commit 4f89943

File tree

6 files changed

+128
-73
lines changed

6 files changed

+128
-73
lines changed
Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
add_executable(tile_example_batched_contraction EXCLUDE_FROM_ALL batched_contraction.cpp)
2-
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
3-
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
2+
set(EXAMPLE_GONTRACTION_COMPILE_OPTIONS)
43
if(CK_USE_OCP_FP8)
5-
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
4+
list(APPEND EXAMPLE_GONTRACTION_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
65
endif()
7-
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
8-
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
9-
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
10-
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
11-
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0")
12-
target_compile_options(tile_example_batched_contraction PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
6+
7+
target_compile_options(tile_example_batched_contraction PRIVATE ${EXAMPLE_GONTRACTION_COMPILE_OPTIONS})

example/ck_tile/40_batched_contraction/contraction_utils.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
struct AddDs
1212
{
13-
template <typename E, typename C, typename D0, typename D1>
14-
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void
13+
template <typename E, typename C, typename... Ds>
14+
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
1515
{
16-
const float x0_f = ck_tile::type_convert<float>(c) + ck_tile::type_convert<float>(d0) +
17-
ck_tile::type_convert<float>(d1);
16+
const float x0_f =
17+
ck_tile::type_convert<float>(c) + (ck_tile::type_convert<float>(ds) + ...);
1818

1919
e = ck_tile::type_convert<E>(x0_f);
2020
}
@@ -50,7 +50,6 @@ auto create_args(int argc, char* argv[])
5050
.insert("k_dims", "2048", "K dimensions separated by comma (e.g., '64,32' for 2D K)")
5151
.insert(
5252
"g_dims", "8", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)")
53-
.insert("num_d", "1", "Number of D tensors (NumDTensor)")
5453
.insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)")
5554
.insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)")
5655
.insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)")

example/ck_tile/40_batched_contraction/run_batched_contraction_example.inc

Lines changed: 120 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,43 @@ void calculate_reference_multi_dimensional(
118118
e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
119119

120120
::EDataType result = static_cast<::EDataType>(sum);
121-
std::vector<::DDataType> d_vals;
122-
for(const auto& d_tensor : ds_full_dims_host)
121+
if(ds_full_dims_host.size() == 0)
123122
{
124-
d_vals.push_back(ck_tile::type_convert<float>(d_tensor(e_idx)));
123+
;
125124
}
126-
if(d_vals.size() == 2)
125+
else if(ds_full_dims_host.size() == 1)
127126
{
128-
cde_elementwise(
129-
result, ck_tile::type_convert<float>(sum), d_vals[0], d_vals[1]);
127+
cde_elementwise(result,
128+
ck_tile::type_convert<float>(sum),
129+
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
130+
}
131+
else if(ds_full_dims_host.size() == 2)
132+
{
133+
cde_elementwise(result,
134+
ck_tile::type_convert<float>(sum),
135+
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
136+
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
137+
}
138+
else if(ds_full_dims_host.size() == 3)
139+
{
140+
cde_elementwise(result,
141+
ck_tile::type_convert<float>(sum),
142+
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
143+
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
144+
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
145+
}
146+
else if(ds_full_dims_host.size() == 4)
147+
{
148+
cde_elementwise(result,
149+
ck_tile::type_convert<float>(sum),
150+
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
151+
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
152+
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
153+
ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
154+
}
155+
else
156+
{
157+
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
130158
}
131159

132160
e_full_dims_host_ref(e_idx) = static_cast<::EDataType>(result);
@@ -165,18 +193,69 @@ void calculate_reference_flat_indexing(
165193
sum += static_cast<::AccDataType>(a_val) * static_cast<::AccDataType>(b_val);
166194
}
167195

168-
std::vector<::DDataType> d_vals;
169-
for(const auto& d_tensor : ds_full_dims_host)
196+
::EDataType result = static_cast<::EDataType>(sum);
197+
if(ds_full_dims_host.size() == 0)
170198
{
171-
d_vals.push_back(ck_tile::type_convert<float>(
172-
d_tensor.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
199+
;
173200
}
174-
::EDataType result = static_cast<::EDataType>(sum);
175-
if(d_vals.size() == 2)
201+
else if(ds_full_dims_host.size() == 1)
202+
{
203+
cde_elementwise(result,
204+
ck_tile::type_convert<float>(sum),
205+
ck_tile::type_convert<float>(
206+
ds_full_dims_host[0].mData[g_flat * M_total * N_total +
207+
m_flat * N_total + n_flat]));
208+
}
209+
else if(ds_full_dims_host.size() == 2)
210+
{
211+
cde_elementwise(
212+
result,
213+
ck_tile::type_convert<float>(sum),
214+
ck_tile::type_convert<float>(
215+
ds_full_dims_host[0]
216+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
217+
ck_tile::type_convert<float>(
218+
ds_full_dims_host[1]
219+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
220+
}
221+
else if(ds_full_dims_host.size() == 3)
222+
{
223+
cde_elementwise(
224+
result,
225+
ck_tile::type_convert<float>(sum),
226+
ck_tile::type_convert<float>(
227+
ds_full_dims_host[0]
228+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
229+
ck_tile::type_convert<float>(
230+
ds_full_dims_host[1]
231+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
232+
ck_tile::type_convert<float>(
233+
ds_full_dims_host[2]
234+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
235+
}
236+
else if(ds_full_dims_host.size() == 4)
176237
{
177238
cde_elementwise(
178-
result, ck_tile::type_convert<float>(sum), d_vals[0], d_vals[1]);
239+
result,
240+
ck_tile::type_convert<float>(sum),
241+
ck_tile::type_convert<float>(
242+
ds_full_dims_host[0]
243+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
244+
ck_tile::type_convert<float>(
245+
ds_full_dims_host[1]
246+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
247+
ck_tile::type_convert<float>(
248+
ds_full_dims_host[2]
249+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
250+
ck_tile::type_convert<float>(
251+
ds_full_dims_host[3]
252+
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
253+
}
254+
else
255+
{
256+
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
179257
}
258+
180259
e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
181260
static_cast<::EDataType>(result);
182261
}
@@ -368,25 +447,34 @@ int run_batched_contraction_example_with_layouts(
368447
ck_tile::HostTensorDescriptor(Ds_dims[d], Ds_strides[d])));
369448
}
370449

371-
ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}(ds_full_dims_host[0]);
372-
ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}(ds_full_dims_host[1]);
373-
374-
ck_tile::DeviceMem d0_full_dims_dev_buf(ds_full_dims_host[0].get_element_space_size_in_bytes());
375-
ck_tile::DeviceMem d1_full_dims_dev_buf(ds_full_dims_host[1].get_element_space_size_in_bytes());
376-
d0_full_dims_dev_buf.ToDevice(ds_full_dims_host[0].data());
377-
d1_full_dims_dev_buf.ToDevice(ds_full_dims_host[1].data());
450+
for(int d = 0; d < NumDTensor; ++d)
451+
{
452+
ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}(
453+
ds_full_dims_host[d]);
454+
}
378455

379-
std::array<const void*, NumDTensor> ds_ptr_buf = {d0_full_dims_dev_buf.GetDeviceBuffer(),
380-
d1_full_dims_dev_buf.GetDeviceBuffer()};
456+
std::vector<std::unique_ptr<ck_tile::DeviceMem>> ds_full_dims_dev_buf;
457+
for(int d = 0; d < NumDTensor; ++d)
458+
{
459+
ds_full_dims_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
460+
ds_full_dims_host[d].get_element_space_size_in_bytes()));
461+
ds_full_dims_dev_buf[d]->ToDevice(ds_full_dims_host[d].data());
462+
}
463+
std::array<const void*, NumDTensor> ds_ptr_buf;
464+
for(int d = 0; d < NumDTensor; ++d)
465+
{
466+
ds_ptr_buf[d] = ds_full_dims_dev_buf[d]->GetDeviceBuffer();
467+
}
381468

382469
e_full_dims_dev_buf.SetZero();
383470
e_full_dims_host.SetZero();
384471

385472
std::cout << "\n=== Running GPU Kernel ===" << std::endl;
386473

387-
using DsDataType = ck_tile::tuple_array<::DDataType, NumDTensor>;
388-
using DsLayout = ck_tile::tuple_array<DLayout, NumDTensor>;
389-
using CDEElementWise = AddDs;
474+
using DsDataType = ck_tile::tuple_array<::DDataType, NumDTensor>;
475+
using DsLayout = ck_tile::tuple_array<DLayout, NumDTensor>;
476+
using CDEElementWise =
477+
std::conditional_t<NumDTensor == 0, ck_tile::element_wise::PassThrough, AddDs>;
390478

391479
float ave_time =
392480
invoke_batched_contraction_kernel<::ADataType,
@@ -427,11 +515,13 @@ int run_batched_contraction_example_with_layouts(
427515
"D, M: " + std::to_string(M_dims.size()) + "D, N: " + std::to_string(N_dims.size()) +
428516
"D, K: " + std::to_string(K_dims.size()) + "D"};
429517

430-
std::size_t flop =
431-
std::size_t(2) * G_total * M_total * N_total * K_total; // Number of operations
432-
std::size_t num_byte = sizeof(::ADataType) * G_total * M_total * K_total + // A tensor size
433-
sizeof(::BDataType) * G_total * N_total * K_total + // B tensor size
434-
sizeof(::EDataType) * G_total * M_total * N_total; // E tensor size
518+
std::size_t flop = std::size_t(2) * G_total * M_total * N_total * K_total +
519+
NumDTensor * K_total * M_total * N_total; // Number of operations
520+
std::size_t num_byte =
521+
sizeof(::ADataType) * G_total * M_total * K_total + // A tensor size
522+
sizeof(::BDataType) * G_total * N_total * K_total + // B tensor size
523+
sizeof(::DDataType) * NumDTensor * G_total * M_total * N_total + // D tensors
524+
sizeof(::EDataType) * G_total * M_total * N_total; // E tensor size
435525

436526
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; // TFlops calculation
437527
float gb_per_sec = num_byte / 1.E6 / ave_time; // GB/s calculation
@@ -443,23 +533,6 @@ int run_batched_contraction_example_with_layouts(
443533
std::cout << " Performance: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
444534
<< " GB/s" << std::endl;
445535

446-
// DETAILED: Tensor information
447-
std::cout << "\nDetailed Tensor Info:" << std::endl;
448-
std::cout << " A tensor: " << G_total << " x " << M_total << " x " << K_total << " = "
449-
<< G_total * M_total * K_total << " elements ("
450-
<< (sizeof(::ADataType) * G_total * M_total * K_total) / 1024 / 1024 << " MB)"
451-
<< std::endl;
452-
std::cout << " B tensor: " << G_total << " x " << N_total << " x " << K_total << " = "
453-
<< G_total * N_total * K_total << " elements ("
454-
<< (sizeof(::BDataType) * G_total * N_total * K_total) / 1024 / 1024 << " MB)"
455-
<< std::endl;
456-
std::cout << " E tensor: " << G_total << " x " << M_total << " x " << N_total << " = "
457-
<< G_total * M_total * N_total << " elements ("
458-
<< (sizeof(::EDataType) * G_total * M_total * N_total) / 1024 / 1024 << " MB)"
459-
<< std::endl;
460-
std::cout << " Total memory: " << num_byte / 1024 / 1024 << " MB" << std::endl;
461-
std::cout << " Total FLOPs: " << flop / 1000000 << " million" << std::endl;
462-
463536
std::cout << "===============================================" << std::endl;
464537

465538
e_full_dims_dev_buf.FromDevice(e_full_dims_host.data());

include/ck_tile/ops/batched_contraction.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#pragma once
55

6-
#include "ck_tile/ops/batched_contraction/kernel/batched_conratction_utils.hpp"
76
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
87
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
98
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"

include/ck_tile/ops/batched_contraction/kernel/batched_conratction_utils.hpp

Lines changed: 0 additions & 10 deletions
This file was deleted.

include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include "ck_tile/core.hpp"
77
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
88
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
9-
#include "ck_tile/ops/batched_contraction/kernel/batched_conratction_utils.hpp"
109

1110
namespace ck_tile {
1211

0 commit comments

Comments (0)