From 076d53a07d540da2a2dbe997332c55947ef5aa13 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Sep 2025 08:58:47 -0500 Subject: [PATCH 01/18] Allow sharing partition index across threads --- include/ck_tile/core/tensor/load_tile.hpp | 29 +++++ .../core/tensor/load_tile_transpose.hpp | 7 +- include/ck_tile/core/tensor/store_tile.hpp | 4 + include/ck_tile/core/tensor/tile_window.hpp | 108 ++++++++++++++---- include/ck_tile/core/tensor/update_tile.hpp | 4 + 5 files changed, 127 insertions(+), 25 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index a3620453b4c..864646156ee 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -19,6 +19,17 @@ namespace ck_tile { template + requires std::is_class_v +CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load(offset, number{}, bool_constant{}); +} + +template + requires std::is_class_v CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, number = {}, bool_constant = {}) @@ -53,6 +64,22 @@ template + requires std::is_class_v && std::is_class_v +CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, + const TileWindow_& tile_window, + index_t offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load( + offset, dst_tile, number{}, bool_constant{}); +} + +template + requires std::is_class_v && std::is_class_v CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, @@ -74,6 +101,7 @@ template & tile_window, number = {}, bool_constant = {}, diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 15352507227..5bed78fdbe5 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -393,6 +393,7 @@ template < typename BottomTensorView_, typename WindowLengths_, typename TileDistribution_, + typename PartitoinIndex_, index_t NumCoord, typename Policy = DefaultTranspose, typename = std::enable_if_t& tile_window) + PartitoinIndex_, + NumCoord>& __restrict__ tile_window, + index_t offset = 0) { using OutTileDstrEncode = typename OutputTileDistributionTraits< typename TileDistribution_::DstrEncode, typename BottomTensorView_::DataType>::TransposedDstrEncode; auto out_tensor = make_static_distributed_tensor( make_static_tile_distribution(OutTileDstrEncode{})); - auto trans_tensor = tile_window.template load_transpose(); + auto trans_tensor = tile_window.template load_transpose(offset); constexpr auto input_distr = TileDistribution_{}; constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index d5a716664d5..3394902e78c 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -64,12 +64,14 @@ store_tile_raw(tile_window_with_static_lengths CK_TILE_DEVICE void store_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor) { @@ -79,12 +81,14 @@ store_tile(tile_window_with_static_distribution CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor) { diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 2db5d719c09..27ec5c1220c 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -33,12 +33,14 @@ namespace ck_tile { template struct tile_window_with_static_distribution : public tile_window_with_tile_dstr_base< tile_window_with_static_distribution, BottomTensorView_, WindowLengths_, @@ -48,6 +50,7 @@ struct tile_window_with_static_distribution tile_window_with_static_distribution, BottomTensorView_, WindowLengths_, @@ -77,8 +80,21 @@ struct tile_window_with_static_distribution this->tile_dstr_ = tile_distribution; const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat( + // use PartitionIndex if all the indices are non-negative + [&] { + if constexpr(0 <= PartitionIndex{}[number<0>{}] && + 0 <= PartitionIndex{}[number<1>{}]) + { + return array{PartitionIndex{}[number<0>{}], + PartitionIndex{}[number<1>{}]}; + } + else + { + return detail::get_partition_index(tile_distribution); + } + }(), + array{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -113,10 +129,21 @@ struct tile_window_with_static_distribution template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const + { + return load(0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - load(dst_tensor, number{}, bool_constant{}); + load(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -233,7 +260,8 @@ struct tile_window_with_static_distribution template - CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + CK_TILE_DEVICE auto load(index_t offset, + DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const { @@ -258,7 +286,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = this->get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, 0, bool_constant{}); + bottom_tensor_thread_coord, offset, bool_constant{}); // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( @@ -509,12 +537,24 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE auto load_transpose() const + CK_TILE_DEVICE auto load_transpose(number = {}, + bool_constant = {}) const + { + return this->template load_transpose( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_transpose(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - this->template load_transpose( - dst_tensor, number{}, bool_constant{}); + this->template load_transpose(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -522,7 +562,8 @@ struct tile_window_with_static_distribution typename DistributedTensor, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true> - CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor, + CK_TILE_DEVICE auto load_transpose(index_t offset, + DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const { @@ -550,7 +591,7 @@ struct tile_window_with_static_distribution const vector_t vec_value = this->get_bottom_tensor_view() .template get_transpose_vectorized_elements( - bottom_tensor_thread_coord, 0); + bottom_tensor_thread_coord, offset); // write into distributed tensor static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto orig_idx_ys = generate_tuple( @@ -914,17 +955,20 @@ struct tile_window_with_static_distribution template + typename PartitionIndex = sequence<-1, -1>, + index_t NumCoord = 1> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, + PartitionIndex = {}, number = {}) { return tile_window_with_static_distribution, remove_cvref_t, remove_cvref_t, + PartitionIndex, NumCoord>{ tensor_view, window_lengths, origin, tile_distribution}; } @@ -933,17 +977,20 @@ make_tile_window(const TensorView_& tensor_view, template + typename PartitionIndex = sequence<-1, -1>, + index_t NumCoord = 1> CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, + PartitionIndex = {}, number = {}) { auto w = tile_window_with_static_distribution, remove_cvref_t, remove_cvref_t, + PartitionIndex, NumCoord>{ tensor_view, window_lengths, origin, tile_distribution}; w.init_raw(); @@ -953,15 +1000,18 @@ make_tile_window_raw(const TensorView_& tensor_view, template CK_TILE_DEVICE void move_tile_window( tile_window_with_static_distribution& window, const typename tile_window_with_static_distribution::BottomTensorIndex& step) { window.move(step); @@ -1108,38 +1158,48 @@ make_tile_window(const tile_window_with_static_lengths +template > CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const multi_index& origin, - const StaticTileDistribution& tile_distribution) + const StaticTileDistribution& tile_distribution, + PartitionIndex = {}) { return make_tile_window(tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin, - tile_distribution); + tile_distribution, + PartitionIndex{}); } -template +template > CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, - const StaticTileDistribution& tile_distribution) + const StaticTileDistribution& tile_distribution, + PartitionIndex = {}) { return make_tile_window(tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), tile_window.get_window_origin(), - tile_distribution); + tile_distribution, + PartitionIndex{}); } -template +template > CK_TILE_DEVICE constexpr auto make_tile_window_raw(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution) { - auto w = make_tile_window(tile_window.get_bottom_tensor_view(), - tile_window.get_window_lengths(), - tile_window.get_window_origin(), - tile_distribution); + auto w = make_tile_window(tile_window, tile_distribution, PartitionIndex{}); w.init_raw(); return w; } @@ -1176,11 +1236,13 @@ struct is_tile_window_with_static_distribution : std::false_type template struct is_tile_window_with_static_distribution< tile_window_with_static_distribution> : std::true_type { }; diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp index 570abde1893..1a7b2b139d9 100644 --- a/include/ck_tile/core/tensor/update_tile.hpp +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -40,6 +40,7 @@ update_tile(tile_window_with_static_lengths& template & tile_window, const static_distributed_tensor& dstr_tensor, number = {}, @@ -59,6 +61,7 @@ update_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor, number = {}, From a2f5011eeb2b673a5d89b8df28910ef05f40257f Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Sep 2025 15:10:46 -0500 Subject: [PATCH 02/18] Fix typo PartitoinIndex -> PartitionIndex --- include/ck_tile/core/tensor/update_tile.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp index 1a7b2b139d9..846f0244d6c 100644 --- a/include/ck_tile/core/tensor/update_tile.hpp +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -40,7 +40,7 @@ update_tile(tile_window_with_static_lengths& template & tile_window, const static_distributed_tensor& dstr_tensor, number = {}, @@ -61,7 +61,7 @@ update_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor, number = {}, From 16d495984f852beabac46f04d98f6a2fd3cd1826 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Sep 2025 15:27:04 -0500 Subject: [PATCH 03/18] Remove C++20 'requires' usages --- include/ck_tile/core/tensor/load_tile.hpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 864646156ee..8d470686391 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -18,8 +18,10 @@ namespace ck_tile { -template - requires std::is_class_v +template >> CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, index_t offset, number = {}, @@ -28,8 +30,10 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, return tile_window.load(offset, number{}, bool_constant{}); } -template - requires std::is_class_v +template >> CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, number = {}, bool_constant = {}) @@ -63,8 +67,9 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, template - requires std::is_class_v && std::is_class_v + bool oob_conditional_check = true, + typename = + std::enable_if_t && std::is_class_v>> CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, index_t offset, @@ -78,8 +83,9 @@ CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, template - requires std::is_class_v && std::is_class_v + bool oob_conditional_check = true, + typename = + std::enable_if_t && std::is_class_v>> CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, From a06f46d4529fc0d7997ed6a15241dca409f1a379 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Sep 2025 15:27:41 -0500 Subject: [PATCH 04/18] Add missing template arguments --- include/ck_tile/core/tensor/tile_window.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 27ec5c1220c..cad68d0cff7 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1020,20 +1020,24 @@ CK_TILE_DEVICE void move_tile_window( template CK_TILE_DEVICE void move_tile_window( tuple>& window, const typename tile_window_with_static_distribution::BottomTensorIndex& step) { using T = tuple>; static constexpr auto N = T::size(); From 7437d10d6d0002c9701721949cb3175fba30c27c Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 24 Sep 2025 05:07:53 -0500 Subject: [PATCH 05/18] Fix load_tile() overload ambiguity issue --- include/ck_tile/core/tensor/load_tile.hpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 37f8b0fc376..4ea54bb5e5d 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -18,12 +18,15 @@ namespace ck_tile { +// Use SFINAE by declaring offset as integral rather than index_t, in order to avoid +// overload ambiguity caused by the implicit number<> to index_t conversion template >> + typename = std::enable_if_t && std::is_integral_v>> CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, - index_t offset, + Offset offset, number = {}, bool_constant = {}) { @@ -64,15 +67,18 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, tile_window, elementwise, number{}, bool_constant{}); } +// Use SFINAE by declaring offset as integral rather than index_t, in order to avoid +// overload ambiguity caused by the implicit number<> to index_t conversion template && std::is_class_v>> + typename = std::enable_if_t && + std::is_class_v && std::is_integral_v>> CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, - index_t offset, + Offset offset, number = {}, bool_constant = {}) { From e0fac21aeaae827905bc6ed3a15f021a72ea73cf Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 26 Sep 2025 01:41:11 -0500 Subject: [PATCH 06/18] Use SFINAE to exclude invalid arguments --- include/ck_tile/core/tensor/load_tile.hpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 4ea54bb5e5d..a45db04c1aa 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -52,10 +52,13 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, * and an elementwise function. For each A = A0, A1… AN, the elementwise function * is additionally applied during a single read. */ -template +template < + typename TileWindow_, + typename ElementWise_, + index_t i_access = -1, + bool oob_conditional_check = true, + typename = std::enable_if_t && std::is_class_v && + !is_constant_v>> CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, ElementWise_ elementwise, number = {}, @@ -90,8 +93,8 @@ template && std::is_class_v>> + typename = std::enable_if_t && + std::is_class_v && !is_constant_v>> CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, From b9b492f5591e4368ba52c1fdbea822d123e0a37d Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 26 Sep 2025 02:44:12 -0500 Subject: [PATCH 07/18] Add additional offset parameter to the async_load_tile() --- include/ck_tile/core/tensor/load_tile.hpp | 8 ++++++-- include/ck_tile/core/tensor/tile_window.hpp | 5 +++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index a45db04c1aa..b1b1cef9fdd 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -158,15 +158,19 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, template + bool oob_conditional_check = true, + typename = std::enable_if_t> && + std::is_class_v && std::is_integral_v>> CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, const TileWindow_& tile_window, + Offset offset = 0, number = {}, bool_constant = {}) { return tile_window.async_load( - lds_tile, number{}, bool_constant{}); + offset, lds_tile, number{}, bool_constant{}); } template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, + CK_TILE_DEVICE auto async_load(index_t offset, + LdsTileWindow_&& lds_tile, number = {}, bool_constant = {}) const { @@ -518,7 +519,7 @@ struct tile_window_with_static_distribution this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, - number<0>{}, + offset, bool_constant{}); // Move thread coordinate if not last access From 27535557bc266d18c7780ac24edb1f57bb22f8b8 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Fri, 26 Sep 2025 04:02:07 -0500 Subject: [PATCH 08/18] Remove async_load_tile() default argument to avoid ambiguity --- include/ck_tile/core/tensor/load_tile.hpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index b1b1cef9fdd..a9bbd86a5bc 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -165,7 +165,7 @@ template && std::is_integral_v>> CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, const TileWindow_& tile_window, - Offset offset = 0, + Offset offset, number = {}, bool_constant = {}) { @@ -173,6 +173,21 @@ CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, offset, lds_tile, number{}, bool_constant{}); } +template > && + std::is_class_v>> +CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + number = {}, + bool_constant = {}) +{ + return async_load_tile( + lds_tile, tile_window, 0, number{}, bool_constant{}); +} + template Date: Sat, 27 Sep 2025 08:54:38 -0500 Subject: [PATCH 09/18] Extract tile_window coordinate compute logic as method --- include/ck_tile/core/tensor/tile_window.hpp | 49 ++++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 890ad76a688..29f6a629e70 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -74,25 +74,39 @@ struct tile_window_with_static_distribution : pre_computed_coords_{} { - this->window_origin_ = window_origin; - this->window_lengths_ = window_lengths; - this->bottom_tensor_view_ = bottom_tensor_view; - this->tile_dstr_ = tile_distribution; + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; + this->tile_dstr_ = tile_distribution; + + pre_computed_coords_ = prepare_coords(bottom_tensor_view, window_origin, tile_distribution); + } + + template + CK_TILE_DEVICE constexpr auto + prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution, + NewPartitionIndex = {}) const + { + array, NumCoord> + coords; + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), container_concat( - // use PartitionIndex if all the indices are non-negative + // use NewPartitionIndex if all the indices are non-negative [&] { - if constexpr(0 <= PartitionIndex{}[number<0>{}] && - 0 <= PartitionIndex{}[number<1>{}]) - { - return array{PartitionIndex{}[number<0>{}], - PartitionIndex{}[number<1>{}]}; - } - else - { - return detail::get_partition_index(tile_distribution); - } + auto partition_index = detail::get_partition_index(tile_distribution); + static_for<0, + ck_tile::min(partition_index.size(), NewPartitionIndex::size()), + 1>{}([&](auto idx) { + if constexpr(0 <= NewPartitionIndex{}[idx]) + { + partition_index[idx] = NewPartitionIndex{}[idx]; + } + }); + return partition_index; }(), array{0})); @@ -121,9 +135,10 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - pre_computed_coords_(iCoord) = - make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); }); + + return coords; } template From 904d9957c8c4132e6f432e001764082c57d7a0ec Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sat, 27 Sep 2025 09:55:44 -0500 Subject: [PATCH 10/18] Use warp-shared LDS base address in tile_window::async_load() --- include/ck_tile/core/tensor/tile_window.hpp | 23 +++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 29f6a629e70..63de70638cb 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -80,6 +80,12 @@ struct tile_window_with_static_distribution this->tile_dstr_ = tile_distribution; pre_computed_coords_ = prepare_coords(bottom_tensor_view, window_origin, tile_distribution); + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + pre_computed_warp_coords_ = prepare_coords( + bottom_tensor_view, window_origin, tile_distribution, sequence<-1, 0>{}); + } } template @@ -95,7 +101,8 @@ struct tile_window_with_static_distribution const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), container_concat( - // use NewPartitionIndex if all the indices are non-negative + // Override partition_index with the corresponding non-negative elements (if + // any) from NewPartitionIndex [&] { auto partition_index = detail::get_partition_index(tile_distribution); static_for<0, @@ -516,12 +523,15 @@ struct tile_window_with_static_distribution auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0]; + auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1]; + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; // Use precomputed window origin auto lds_bottom_tensor_thread_idx = - window_origin + window_adaptor_thread_coord.get_bottom_index(); + window_origin + window_adaptor_warp_coord.get_bottom_index(); // Use precomputed tensor descriptor const auto lds_coord = @@ -547,6 +557,9 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys); } }); }); @@ -965,6 +978,12 @@ struct tile_window_with_static_distribution // per-thread coordinate for bottom tensor array, NumCoord> pre_computed_coords_; + // pre_computed_warp_coords_ exists only in the global memory tile_window + std::conditional_t< + Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global, + array, NumCoord>, + std::byte> + pre_computed_warp_coords_; }; // TODO: use strategy From f7ed84fc3fecbfc0bcc88eb113628228ba0c8936 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 28 Sep 2025 11:58:01 -0500 Subject: [PATCH 11/18] Add constraint to tile_window::load() templates --- include/ck_tile/core/tensor/tile_window.hpp | 44 ++++++++++++++------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 63de70638cb..b44e26cadd1 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -155,8 +155,13 @@ struct tile_window_with_static_distribution return load(0, number{}, bool_constant{}); } - template - CK_TILE_DEVICE auto load(index_t offset, + // Use SFINAE by declaring offset as integral rather than index_t, in order to avoid + // overload ambiguity caused by the implicit number<> to index_t conversion + template >> + CK_TILE_DEVICE auto load(Offset offset, number = {}, bool_constant = {}) const { @@ -179,10 +184,13 @@ struct tile_window_with_static_distribution * The same thread, during vectorized reading, accesses the same set of * data from A0, A1, A2, … AN. */ - template + template < + typename TileWindow_, + typename ElementWise_, + index_t i_access_unsupport_ = -1, + bool oob_conditional_check = true, + typename = std::enable_if_t && std::is_class_v && + !is_constant_v>> CK_TILE_DEVICE auto load(const TileWindow_& tile_window, ElementWise_ elementwise, number = {}, @@ -198,11 +206,15 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template < + typename DistributedTensor, + typename TileWindow_, + typename ElementWise_, + index_t i_access_unsupport_ = -1, + bool oob_conditional_check = true, + typename = std::enable_if_t> && + std::is_class_v && std::is_class_v && + !is_constant_v>> CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, const TileWindow_& tile_window, ElementWise_ elementwise, @@ -279,10 +291,14 @@ struct tile_window_with_static_distribution }); } - template - CK_TILE_DEVICE auto load(index_t offset, + bool oob_conditional_check = true, + typename = std::enable_if_t && + std::is_class_v> && + !is_constant_v>>> + CK_TILE_DEVICE auto load(Offset offset, DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const From 79b9b9ef17ee4d6a485d3a10d1af3c125e344765 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 28 Sep 2025 12:05:35 -0500 Subject: [PATCH 12/18] Fix wrong type traits is_class_v<> usages --- include/ck_tile/core/tensor/load_tile.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index a9bbd86a5bc..1dc31243688 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -77,8 +77,8 @@ template && - std::is_class_v && std::is_integral_v>> + typename = std::enable_if_t> && + std::is_class_v && std::is_integral_v>> CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, Offset offset, @@ -93,8 +93,8 @@ template && - std::is_class_v && !is_constant_v>> + typename = std::enable_if_t> && + std::is_class_v && !is_constant_v>> CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, From 0b16363c55ff48824ede33c6c807447bdcceb73a Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 28 Sep 2025 12:12:24 -0500 Subject: [PATCH 13/18] Add missing constraint to async_load_tile() --- include/ck_tile/core/tensor/load_tile.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 1dc31243688..5c5e18b298d 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -178,7 +178,7 @@ template > && - std::is_class_v>> + std::is_class_v && !is_constant_v>> CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, const TileWindow_& tile_window, number = {}, From 12870fd88382a81121aed0507bcc3327823802dc Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Sun, 28 Sep 2025 13:10:08 -0500 Subject: [PATCH 14/18] Add missing tile_window::load() overload --- include/ck_tile/core/tensor/tile_window.hpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index b44e26cadd1..bf31caacec6 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -291,6 +291,18 @@ struct tile_window_with_static_distribution }); } + template > && + !is_constant_v>>> + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const + { + load(0, dst_tensor, number{}, bool_constant{}); + } + template Date: Sun, 28 Sep 2025 13:16:49 -0500 Subject: [PATCH 15/18] Add more constraint to avoid load_tile() call ambiguity --- include/ck_tile/core/tensor/load_tile.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 5c5e18b298d..b9ff58b94fb 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -36,7 +36,7 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, template >> + typename = std::enable_if_t && !is_constant_v>> CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, number = {}, bool_constant = {}) From f81b6bf260ea54d9f6a5ac7734e76c9b735a0cb3 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Mon, 29 Sep 2025 21:31:22 -0500 Subject: [PATCH 16/18] Rename ParitionIndex as ReplacementPartitionIndex --- include/ck_tile/core/tensor/store_tile.hpp | 8 +- include/ck_tile/core/tensor/tile_window.hpp | 91 +++++++++++---------- include/ck_tile/core/tensor/update_tile.hpp | 8 +- 3 files changed, 55 insertions(+), 52 deletions(-) diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index 3394902e78c..899bad0d8ea 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -64,14 +64,14 @@ store_tile_raw(tile_window_with_static_lengths CK_TILE_DEVICE void store_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor) { @@ -81,14 +81,14 @@ store_tile(tile_window_with_static_distribution CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor) { diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index bf31caacec6..07d86c3f9c5 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -25,22 +25,23 @@ namespace ck_tile { * @note This tile window does not support single issue you need to use tile_window_linear * structure for this purpose * - * @tparam BottomTensorView_ Class describing & holding device tensor memory. - * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. - * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions - * @tparam NumCoord TBD + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. + * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions + * @tparam ReplacementPartitionIndex Replacement values of (get_warp_id(), get_lane_id()) tuple + * @tparam NumCoord TBD */ template struct tile_window_with_static_distribution : public tile_window_with_tile_dstr_base< tile_window_with_static_distribution, BottomTensorView_, WindowLengths_, @@ -50,7 +51,7 @@ struct tile_window_with_static_distribution tile_window_with_static_distribution, BottomTensorView_, WindowLengths_, @@ -88,12 +89,12 @@ struct tile_window_with_static_distribution } } - template + template CK_TILE_DEVICE constexpr auto prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view, const typename Base::BottomTensorIndex& window_origin, const typename Base::TileDstr& tile_distribution, - NewPartitionIndex = {}) const + NewReplacementPartitionIndex = {}) const { array, NumCoord> coords; @@ -102,15 +103,16 @@ struct tile_window_with_static_distribution tile_distribution.get_ps_ys_to_xs_adaptor(), container_concat( // Override partition_index with the corresponding non-negative elements (if - // any) from NewPartitionIndex + // any) from NewReplacementPartitionIndex [&] { auto partition_index = detail::get_partition_index(tile_distribution); static_for<0, - ck_tile::min(partition_index.size(), NewPartitionIndex::size()), + ck_tile::min(partition_index.size(), + NewReplacementPartitionIndex::size()), 1>{}([&](auto idx) { - if constexpr(0 <= NewPartitionIndex{}[idx]) + if constexpr(0 <= NewReplacementPartitionIndex{}[idx]) { - partition_index[idx] = NewPartitionIndex{}[idx]; + partition_index[idx] = NewReplacementPartitionIndex{}[idx]; } }); return partition_index; @@ -1018,20 +1020,20 @@ struct tile_window_with_static_distribution template , - index_t NumCoord = 1> + typename ReplacementPartitionIndex = sequence<-1, -1>, + index_t NumCoord = 1> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, - PartitionIndex = {}, - number = {}) + ReplacementPartitionIndex = {}, + number = {}) { return tile_window_with_static_distribution, remove_cvref_t, remove_cvref_t, - PartitionIndex, + ReplacementPartitionIndex, NumCoord>{ tensor_view, window_lengths, origin, tile_distribution}; } @@ -1040,20 +1042,20 @@ make_tile_window(const TensorView_& tensor_view, template , - index_t NumCoord = 1> + typename ReplacementPartitionIndex = sequence<-1, -1>, + index_t NumCoord = 1> CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, - PartitionIndex = {}, - number = {}) + ReplacementPartitionIndex = {}, + number = {}) { auto w = tile_window_with_static_distribution, remove_cvref_t, remove_cvref_t, - PartitionIndex, + ReplacementPartitionIndex, NumCoord>{ tensor_view, window_lengths, origin, tile_distribution}; w.init_raw(); @@ -1063,18 +1065,18 @@ make_tile_window_raw(const TensorView_& tensor_view, template CK_TILE_DEVICE void move_tile_window( tile_window_with_static_distribution& window, const typename tile_window_with_static_distribution::BottomTensorIndex& step) { window.move(step); @@ -1083,24 +1085,24 @@ CK_TILE_DEVICE void move_tile_window( template CK_TILE_DEVICE void move_tile_window( tuple>& window, const typename tile_window_with_static_distribution::BottomTensorIndex& step) { using T = tuple>; static constexpr auto N = T::size(); @@ -1228,45 +1230,45 @@ make_tile_window(const tile_window_with_static_lengths> + typename ReplacementPartitionIndex = sequence<-1, -1>> CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const multi_index& origin, const StaticTileDistribution& tile_distribution, - PartitionIndex = {}) + ReplacementPartitionIndex = {}) { return make_tile_window(tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin, tile_distribution, - PartitionIndex{}); + ReplacementPartitionIndex{}); } template > + typename ReplacementPartitionIndex = sequence<-1, -1>> CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution, - PartitionIndex = {}) + ReplacementPartitionIndex = {}) { return make_tile_window(tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), tile_window.get_window_origin(), tile_distribution, - PartitionIndex{}); + ReplacementPartitionIndex{}); } template > + typename ReplacementPartitionIndex = sequence<-1, -1>> CK_TILE_DEVICE constexpr auto make_tile_window_raw(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution) { - auto w = make_tile_window(tile_window, tile_distribution, PartitionIndex{}); + auto w = make_tile_window(tile_window, tile_distribution, ReplacementPartitionIndex{}); w.init_raw(); return w; } @@ -1295,21 +1297,22 @@ struct is_tile_window_with_static_distribution : std::false_type /** * @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`. * - * @tparam BottomTensorView_ Bottom tensor view type of the tile window. - * @tparam WindowLengths_ Static window lengths. - * @tparam StaticTileDistribution_ Tile distribution policy. - * @tparam NumCoord Number of coordinate dimensions. + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. + * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions + * @tparam ReplacementPartitionIndex Replacement values of (get_warp_id(), get_lane_id()) tuple + * @tparam NumCoord TBD */ template struct is_tile_window_with_static_distribution< tile_window_with_static_distribution> : std::true_type { }; diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp index 846f0244d6c..553b763bbaa 100644 --- a/include/ck_tile/core/tensor/update_tile.hpp +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -40,7 +40,7 @@ update_tile(tile_window_with_static_lengths& template & tile_window, const static_distributed_tensor& dstr_tensor, number = {}, @@ -61,7 +61,7 @@ update_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor, number = {}, From 5c5e2dda58b7e77b29487f3db396289b5a34f2cf Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 1 Oct 2025 11:13:50 -0500 Subject: [PATCH 17/18] Update pre_computed_warp_coords_ in move_extended() --- include/ck_tile/core/tensor/tile_window.hpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 07d86c3f9c5..6302d77102c 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -962,6 +962,16 @@ struct tile_window_with_static_distribution pre_computed_coords_(iCoord)(I1), step); }); + + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_warp_coords_(iCoord)(I1), + step); + }); + } } CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) From 28246a5dd65d5e3082b4c4a66c790d6c02268263 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 1 Oct 2025 12:36:20 -0500 Subject: [PATCH 18/18] Fix inconsistency between template parameters and documentation --- include/ck_tile/core/tensor/tile_window.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 6302d77102c..53d250766f3 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1316,13 +1316,13 @@ struct is_tile_window_with_static_distribution : std::false_type template struct is_tile_window_with_static_distribution< tile_window_with_static_distribution> : std::true_type { };