diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 2e9ab0f5c6c..b9ff58b94fb 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -18,7 +18,25 @@ namespace ck_tile { -template +// Use SFINAE by declaring offset as integral rather than index_t, in order to avoid +// overload ambiguity caused by the implicit number<> to index_t conversion +template && std::is_integral_v>> +CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, + Offset offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load(offset, number{}, bool_constant{}); +} + +template && !is_constant_v>> CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, number = {}, bool_constant = {}) @@ -34,10 +52,13 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, * and an elementwise function. For each A = A0, A1… AN, the elementwise function * is additionally applied during a single read. */ -template +template < + typename TileWindow_, + typename ElementWise_, + index_t i_access = -1, + bool oob_conditional_check = true, + typename = std::enable_if_t && std::is_class_v && + !is_constant_v>> CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, ElementWise_ elementwise, number = {}, @@ -49,10 +70,31 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, tile_window, elementwise, number{}, bool_constant{}); } +// Use SFINAE by declaring offset as integral rather than index_t, in order to avoid +// overload ambiguity caused by the implicit number<> to index_t conversion +template > && + std::is_class_v && std::is_integral_v>> +CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, + const TileWindow_& tile_window, + Offset offset, + number = {}, + bool_constant = {}) +{ + return tile_window.load( + offset, dst_tile, number{}, bool_constant{}); +} + template + bool oob_conditional_check = true, + typename = std::enable_if_t> && + std::is_class_v && !is_constant_v>> CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, @@ -74,6 +116,7 @@ template & tile_window, number = {}, bool_constant = {}, @@ -114,15 +158,34 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, template + bool oob_conditional_check = true, + typename = std::enable_if_t> && + std::is_class_v && std::is_integral_v>> CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, const TileWindow_& tile_window, + Offset offset, number = {}, bool_constant = {}) { return tile_window.async_load( - lds_tile, number{}, bool_constant{}); + offset, lds_tile, number{}, bool_constant{}); +} + +template > && + std::is_class_v && !is_constant_v>> +CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + number = {}, + bool_constant = {}) +{ + return async_load_tile( + lds_tile, tile_window, 0, number{}, bool_constant{}); } template , typename = std::enable_if_t& tile_window) + PartitoinIndex_, + NumCoord>& __restrict__ tile_window, + index_t offset = 0) { using OutTileDstrEncode = typename OutputTileDistributionTraits< typename TileDistribution_::DstrEncode, typename BottomTensorView_::DataType>::TransposedDstrEncode; auto out_tensor = make_static_distributed_tensor( make_static_tile_distribution(OutTileDstrEncode{})); - auto trans_tensor = tile_window.template load_transpose(); + auto trans_tensor = tile_window.template load_transpose(offset); constexpr auto input_distr = TileDistribution_{}; constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index d5a716664d5..899bad0d8ea 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -64,12 +64,14 @@ store_tile_raw(tile_window_with_static_lengths CK_TILE_DEVICE void store_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor) { @@ -79,12 +81,14 @@ store_tile(tile_window_with_static_distribution CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor) { diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 2db5d719c09..53d250766f3 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -25,20 +25,23 @@ namespace ck_tile { * @note This tile window does not support single issue you need to use tile_window_linear * structure for this purpose * - * @tparam BottomTensorView_ Class describing & holding device tensor memory. - * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. - * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions - * @tparam NumCoord TBD + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. + * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions + * @tparam ReplacementPartitionIndex Replacement values of (get_warp_id(), get_lane_id()) tuple + * @tparam NumCoord TBD */ template struct tile_window_with_static_distribution : public tile_window_with_tile_dstr_base< tile_window_with_static_distribution, BottomTensorView_, WindowLengths_, @@ -48,6 +51,7 @@ struct tile_window_with_static_distribution tile_window_with_static_distribution, BottomTensorView_, WindowLengths_, @@ -71,14 +75,49 @@ struct tile_window_with_static_distribution : pre_computed_coords_{} { - this->window_origin_ = window_origin; - this->window_lengths_ = window_lengths; - this->bottom_tensor_view_ = bottom_tensor_view; - this->tile_dstr_ = tile_distribution; + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; + this->tile_dstr_ = tile_distribution; + + pre_computed_coords_ = prepare_coords(bottom_tensor_view, window_origin, tile_distribution); + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + pre_computed_warp_coords_ = prepare_coords( + bottom_tensor_view, window_origin, tile_distribution, sequence<-1, 0>{}); + } + } + + template + CK_TILE_DEVICE constexpr auto + prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution, + NewReplacementPartitionIndex = {}) const + { + array, NumCoord> + coords; + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_distribution), - array{0})); + container_concat( + // Override partition_index with the corresponding non-negative elements (if + // any) from NewReplacementPartitionIndex + [&] { + auto partition_index = detail::get_partition_index(tile_distribution); + static_for<0, + ck_tile::min(partition_index.size(), + NewReplacementPartitionIndex::size()), + 1>{}([&](auto idx) { + if constexpr(0 <= NewReplacementPartitionIndex{}[idx]) + { + partition_index[idx] = NewReplacementPartitionIndex{}[idx]; + } + }); + return partition_index; + }(), + array{0})); typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); @@ -105,18 +144,35 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - pre_computed_coords_(iCoord) = - make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); }); + + return coords; } template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const + { + return load(0, number{}, bool_constant{}); + } + + // Use SFINAE by declaring offset as integral rather than index_t, in order to avoid + // overload ambiguity caused by the implicit number<> to index_t conversion + template >> + CK_TILE_DEVICE auto load(Offset offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - load(dst_tensor, number{}, bool_constant{}); + load(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -130,10 +186,13 @@ struct tile_window_with_static_distribution * The same thread, during vectorized reading, accesses the same set of * data from A0, A1, A2, … AN. */ - template + template < + typename TileWindow_, + typename ElementWise_, + index_t i_access_unsupport_ = -1, + bool oob_conditional_check = true, + typename = std::enable_if_t && std::is_class_v && + !is_constant_v>> CK_TILE_DEVICE auto load(const TileWindow_& tile_window, ElementWise_ elementwise, number = {}, @@ -149,11 +208,15 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template < + typename DistributedTensor, + typename TileWindow_, + typename ElementWise_, + index_t i_access_unsupport_ = -1, + bool oob_conditional_check = true, + typename = std::enable_if_t> && + std::is_class_v && std::is_class_v && + !is_constant_v>> CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, const TileWindow_& tile_window, ElementWise_ elementwise, @@ -232,10 +295,27 @@ struct tile_window_with_static_distribution template + bool oob_conditional_check = true, + typename = std::enable_if_t> && + !is_constant_v>>> CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const + { + load(0, dst_tensor, number{}, bool_constant{}); + } + + template && + std::is_class_v> && + !is_constant_v>>> + CK_TILE_DEVICE auto load(Offset offset, + DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const { using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; @@ -258,7 +338,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = this->get_bottom_tensor_view().template get_vectorized_elements( - bottom_tensor_thread_coord, 0, bool_constant{}); + bottom_tensor_thread_coord, offset, bool_constant{}); // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( @@ -451,7 +531,8 @@ struct tile_window_with_static_distribution template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, + CK_TILE_DEVICE auto async_load(index_t offset, + LdsTileWindow_&& lds_tile, number = {}, bool_constant = {}) const { @@ -472,12 +553,15 @@ struct tile_window_with_static_distribution auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0]; + auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1]; + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; // Use precomputed window origin auto lds_bottom_tensor_thread_idx = - window_origin + window_adaptor_thread_coord.get_bottom_index(); + window_origin + window_adaptor_warp_coord.get_bottom_index(); // Use precomputed tensor descriptor const auto lds_coord = @@ -490,7 +574,7 @@ struct tile_window_with_static_distribution this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, - number<0>{}, + offset, bool_constant{}); // Move thread coordinate if not last access @@ -503,18 +587,33 @@ struct tile_window_with_static_distribution Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys); } }); }); } template - CK_TILE_DEVICE auto load_transpose() const + CK_TILE_DEVICE auto load_transpose(number = {}, + bool_constant = {}) const + { + return this->template load_transpose( + 0, number{}, bool_constant{}); + } + + template + CK_TILE_DEVICE auto load_transpose(index_t offset, + number = {}, + bool_constant = {}) const { constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - this->template load_transpose( - dst_tensor, number{}, bool_constant{}); + this->template load_transpose(offset, + dst_tensor, + number{}, + bool_constant{}); return dst_tensor; } @@ -522,7 +621,8 @@ struct tile_window_with_static_distribution typename DistributedTensor, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true> - CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor, + CK_TILE_DEVICE auto load_transpose(index_t offset, + DistributedTensor& dst_tensor, number = {}, bool_constant = {}) const { @@ -550,7 +650,7 @@ struct tile_window_with_static_distribution const vector_t vec_value = this->get_bottom_tensor_view() .template get_transpose_vectorized_elements( - bottom_tensor_thread_coord, 0); + bottom_tensor_thread_coord, offset); // write into distributed tensor static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto orig_idx_ys = generate_tuple( @@ -862,6 +962,16 @@ struct tile_window_with_static_distribution pre_computed_coords_(iCoord)(I1), step); }); + + if constexpr(Base::BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_warp_coords_(iCoord)(I1), + step); + }); + } } CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) @@ -908,23 +1018,32 @@ struct tile_window_with_static_distribution // per-thread coordinate for bottom tensor array, NumCoord> pre_computed_coords_; + // pre_computed_warp_coords_ exists only in the global memory tile_window + std::conditional_t< + Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global, + array, NumCoord>, + std::byte> + pre_computed_warp_coords_; }; // TODO: use strategy template + typename ReplacementPartitionIndex = sequence<-1, -1>, + index_t NumCoord = 1> CK_TILE_DEVICE constexpr auto make_tile_window(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, - number = {}) + ReplacementPartitionIndex = {}, + number = {}) { return tile_window_with_static_distribution, remove_cvref_t, remove_cvref_t, + ReplacementPartitionIndex, NumCoord>{ tensor_view, window_lengths, origin, tile_distribution}; } @@ -933,17 +1052,20 @@ make_tile_window(const TensorView_& tensor_view, template + typename ReplacementPartitionIndex = sequence<-1, -1>, + index_t NumCoord = 1> CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, - number = {}) + ReplacementPartitionIndex = {}, + number = {}) { auto w = tile_window_with_static_distribution, remove_cvref_t, remove_cvref_t, + ReplacementPartitionIndex, NumCoord>{ tensor_view, window_lengths, origin, tile_distribution}; w.init_raw(); @@ -953,15 +1075,18 @@ make_tile_window_raw(const TensorView_& tensor_view, template CK_TILE_DEVICE void move_tile_window( tile_window_with_static_distribution& window, const typename tile_window_with_static_distribution::BottomTensorIndex& step) { window.move(step); @@ -970,20 +1095,24 @@ CK_TILE_DEVICE void move_tile_window( template CK_TILE_DEVICE void move_tile_window( tuple>& window, const typename tile_window_with_static_distribution::BottomTensorIndex& step) { using T = tuple>; static constexpr auto N = T::size(); @@ -1108,38 +1237,48 @@ make_tile_window(const tile_window_with_static_lengths +template > CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, const multi_index& origin, - const StaticTileDistribution& tile_distribution) + const StaticTileDistribution& tile_distribution, + ReplacementPartitionIndex = {}) { return make_tile_window(tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin, - tile_distribution); + tile_distribution, + ReplacementPartitionIndex{}); } -template +template > CK_TILE_DEVICE constexpr auto make_tile_window(const tile_window_with_static_lengths& tile_window, - const StaticTileDistribution& tile_distribution) + const StaticTileDistribution& tile_distribution, + ReplacementPartitionIndex = {}) { return make_tile_window(tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), tile_window.get_window_origin(), - tile_distribution); + tile_distribution, + ReplacementPartitionIndex{}); } -template +template > CK_TILE_DEVICE constexpr auto make_tile_window_raw(const tile_window_with_static_lengths& tile_window, const StaticTileDistribution& tile_distribution) { - auto w = make_tile_window(tile_window.get_bottom_tensor_view(), - tile_window.get_window_lengths(), - tile_window.get_window_origin(), - tile_distribution); + auto w = make_tile_window(tile_window, tile_distribution, ReplacementPartitionIndex{}); w.init_raw(); return w; } @@ -1168,19 +1307,22 @@ struct is_tile_window_with_static_distribution : std::false_type /** * @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`. * - * @tparam BottomTensorView_ Bottom tensor view type of the tile window. - * @tparam WindowLengths_ Static window lengths. - * @tparam StaticTileDistribution_ Tile distribution policy. - * @tparam NumCoord Number of coordinate dimensions. + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. + * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions + * @tparam ReplacementPartitionIndex Replacement values of (get_warp_id(), get_lane_id()) tuple + * @tparam NumCoord TBD */ template struct is_tile_window_with_static_distribution< tile_window_with_static_distribution> : std::true_type { }; diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp index 570abde1893..553b763bbaa 100644 --- a/include/ck_tile/core/tensor/update_tile.hpp +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -40,6 +40,7 @@ update_tile(tile_window_with_static_lengths& template & tile_window, const static_distributed_tensor& dstr_tensor, number = {}, @@ -59,6 +61,7 @@ update_tile(tile_window_with_static_distribution& tile_window, const static_distributed_tensor& dstr_tensor, number = {},