Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
076d53a
Allow sharing partition index across threads
poyenc Sep 23, 2025
a2f5011
Fix typo PartitoinIndex -> PartitionIndex
poyenc Sep 23, 2025
16d4959
Remove C++20 'requires' usages
poyenc Sep 23, 2025
a06f46d
Add missing template arguments
poyenc Sep 23, 2025
570492d
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 24, 2025
d4925f4
Merge remote-tracking branch 'origin/develop' into poyenc/specify-off…
poyenc Sep 24, 2025
7437d10
Fix load_tile() overload ambiguity issue
poyenc Sep 24, 2025
bf9234c
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 25, 2025
e0fac21
Use SFINAE to exclude invalid arguments
poyenc Sep 26, 2025
7dd874f
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 26, 2025
b9b492f
Add additional offset parameter to the async_load_tile()
poyenc Sep 26, 2025
2753555
Remove async_load_tile() default argument to avoid ambiguity
poyenc Sep 26, 2025
c539872
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 26, 2025
ec1a615
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 27, 2025
804b336
Extract tile_window coordinate compute logic as method
poyenc Sep 27, 2025
904d995
Use warp-shared LDS base address in tile_window::async_load()
poyenc Sep 27, 2025
85eab02
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 27, 2025
f7ed84f
Add constraint to tile_window::load() templates
poyenc Sep 28, 2025
79b9b9e
Fix wrong type traits is_class_v<> usages
poyenc Sep 28, 2025
0b16363
Add missing constraint to async_load_tile()
poyenc Sep 28, 2025
6caa88b
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 28, 2025
12870fd
Add missing tile_window::load() overload
poyenc Sep 28, 2025
8f56ee8
Add more constraint to avoid load_tile() call ambiguity
poyenc Sep 28, 2025
64d4aaa
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 30, 2025
f81b6bf
Rename ParitionIndex as ReplacementPartitionIndex
poyenc Sep 30, 2025
56604d3
Merge branch 'develop' into poyenc/specify-offset-load-tile
poyenc Sep 30, 2025
5c5e2dd
Update pre_computed_warp_coords_ in move_extended()
poyenc Oct 1, 2025
28246a5
Fix inconsistency between template parameters and documentation
poyenc Oct 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 71 additions & 8 deletions include/ck_tile/core/tensor/load_tile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,25 @@

namespace ck_tile {

template <typename TileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
// Use SFINAE by declaring offset as integral<Offset> rather than index_t, in order to avoid
// overload ambiguity caused by the implicit number<> to index_t conversion
template <typename TileWindow_,
typename Offset,
index_t i_access = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<TileWindow_> && std::is_integral_v<Offset>>>
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
Offset offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(offset, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<TileWindow_> && !is_constant_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
Expand All @@ -34,10 +52,13 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
* is additionally applied during a single read.
*/
template <typename TileWindow_,
typename ElementWise_,
index_t i_access = -1,
bool oob_conditional_check = true>
template <
typename TileWindow_,
typename ElementWise_,
index_t i_access = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<TileWindow_> && std::is_class_v<ElementWise_> &&
!is_constant_v<ElementWise_>>>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
ElementWise_ elementwise,
number<i_access> = {},
Expand All @@ -49,10 +70,31 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

// Use SFINAE by declaring offset as integral<Offset> rather than index_t, in order to avoid
// overload ambiguity caused by the implicit number<> to index_t conversion
template <typename DistributedTensor_,
typename TileWindow_,
typename Offset,
index_t i_access = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
std::is_class_v<TileWindow_> && std::is_integral_v<Offset>>>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
Offset offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(
offset, dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
std::is_class_v<TileWindow_> && !is_constant_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
number<i_access> = {},
Expand All @@ -74,6 +116,7 @@ template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename PartitoinIndex_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
Expand All @@ -82,6 +125,7 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
PartitoinIndex_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
Expand Down Expand Up @@ -114,15 +158,34 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,

template <typename LdsTileWindow_,
typename TileWindow_,
typename Offset,
index_t i_access = -1,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
std::is_class_v<TileWindow_> && std::is_integral_v<Offset>>>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
Offset offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.async_load(
lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
offset, lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
std::is_class_v<TileWindow_> && !is_constant_v<TileWindow_>>>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return async_load_tile(
lds_tile, tile_window, 0, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename LdsTileWindow_,
Expand Down
7 changes: 5 additions & 2 deletions include/ck_tile/core/tensor/load_tile_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ template <
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename PartitoinIndex_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
Expand All @@ -403,14 +404,16 @@ CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
PartitoinIndex_,
NumCoord>& __restrict__ tile_window,
index_t offset = 0)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose<Policy>();
auto trans_tensor = tile_window.template load_transpose<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});

Expand Down
4 changes: 4 additions & 0 deletions include/ck_tile/core/tensor/store_tile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename ReplacementPartitionIndex_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
ReplacementPartitionIndex_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
Expand All @@ -79,12 +81,14 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename ReplacementPartitionIndex_,
index_t NumCoord,
typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
ReplacementPartitionIndex_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
Expand Down
Loading
Loading