Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 193 additions & 1 deletion include/ck/tensor_description/multi_index_transform.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -1553,6 +1553,198 @@ struct UnMerge
}
};

/**
 * @brief Transformation struct for convolution backward data output indices to GEMM indices.
 *
 * Maps the implicit-GEMM upper indices (K0, M, K1) back to the convolution
 * backward-data output tensor's lower indices (N, Ho, Wo, K). The M dimension
 * is a merge of (N, HTildeSlice, WTildeSlice) shifted by the slice begin
 * offsets; the flattened K = K0 * K1_ + K1 dimension is a merge of
 * (YDotSlice, XDotSlice, K) scaled by the H/W embed ratios. Divisions by the
 * runtime lengths are performed with precomputed magic-number divisors.
 */
struct ConvBwdDataImplicitGemmOutTransform
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
    using UpperIndex = MultiIndex<3>; // K0, M, K1

    // Output tensor extents.
    index_t N_, Ho_, Wo_, K_;
    index_t XDot_;
    index_t HTilde_, WTilde_;
    // WTildeSlice_ is the W slice extent; TildeSlice_ is HTildeSlice * WTildeSlice.
    index_t WTildeSlice_, TildeSlice_;
    // Offsets added to the sliced H/W steps (Slice transform).
    index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
    // Embed multipliers for the Y/X steps (negative dilation / gcd ratios).
    index_t HRatio_, WRatio_;
    // XDotSlice * K: length of the merged (XDotSlice, K) sub-dimension.
    index_t XDotSlice_K_;
    // Right-padding applied to M and to the flattened K dimension.
    index_t MPad_, KPad_;
    Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_;

    Tuple<index_t, index_t, index_t, index_t>
        low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
    Tuple<index_t, index_t, index_t, index_t>
        low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_

    __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;

    /// All divisor operands (XDotSlice_K, K, HWTildeSlice, WTildeSlice) are
    /// converted to magic multiplier/shift pairs once here, so the per-index
    /// calculations below avoid hardware integer division.
    __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
                                                                      index_t Ho,
                                                                      index_t Wo,
                                                                      index_t K,
                                                                      index_t XDot,
                                                                      index_t HTilde,
                                                                      index_t WTilde,
                                                                      index_t WTildeSlice,
                                                                      index_t HWTildeSlice,
                                                                      index_t IHTildeSliceBegin,
                                                                      index_t IWTildeSliceBegin,
                                                                      index_t HRatio,
                                                                      index_t WRatio,
                                                                      index_t XDotSlice_K,
                                                                      index_t K0,
                                                                      index_t MPadded,
                                                                      index_t K1,
                                                                      index_t MPad,
                                                                      index_t KPad)
        : N_{N},
          Ho_{Ho},
          Wo_{Wo},
          K_{K},
          XDot_{XDot},
          HTilde_{HTilde},
          WTilde_{WTilde},
          WTildeSlice_{WTildeSlice},
          TildeSlice_{HWTildeSlice},
          IHTildeSliceBegin_{IHTildeSliceBegin},
          IWTildeSliceBegin_{IWTildeSliceBegin},
          HRatio_{HRatio},
          WRatio_{WRatio},
          XDotSlice_K_{XDotSlice_K},
          MPad_{MPad},
          KPad_{KPad},
          up_lengths_{make_tuple(K0, MPadded, K1)},
          low_lengths_magic_divisor_multiplier_{
              MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
              MagicDivision::CalculateMagicMultiplier(K_),
              MagicDivision::CalculateMagicMultiplier(TildeSlice_),
              MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
          low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
                                           MagicDivision::CalculateMagicShift(K_),
                                           MagicDivision::CalculateMagicShift(TildeSlice_),
                                           MagicDivision::CalculateMagicShift(WTildeSlice_)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    /// Decomposes the M index (idx_up[I1]) into (N, H, W) steps — the merge of
    /// (N, HTildeSlice, WTildeSlice) — and applies the slice begin offsets.
    /// The K component of the returned lower index is 0; it is supplied by
    /// CalculateLowerIndexK and the two partial results are summed.
    template <typename UpIdx>
    __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
    {
        index_t NStep, HStep, WStep;
        // Merge
        // NStep = M_id / TildeSlice_
        NStep = MagicDivision::DoMagicDivision(idx_up[I1],
                                               this->low_lengths_magic_divisor_multiplier_[I2],
                                               this->low_lengths_magic_divisor_shift_[I2]);
        HStep = idx_up[I1] - NStep * TildeSlice_;
        // HStep = HStep / WTildeSlice_
        HStep = MagicDivision::DoMagicDivision(HStep,
                                               this->low_lengths_magic_divisor_multiplier_[I3],
                                               this->low_lengths_magic_divisor_shift_[I3]);
        WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
        // Slice
        HStep += IHTildeSliceBegin_;
        WStep += IWTildeSliceBegin_;

        return make_tuple(NStep, HStep, WStep, 0);
    }

    /// Decomposes the flattened K index (K0 * K1_len + K1) into (Y, X, K)
    /// steps and scales Y/X by the embed ratios. The N component of the
    /// returned lower index is 0; see CalculateLowerIndexN.
    template <typename UpIdx>
    __host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
    {
        // UnMerge
        // K_idx <- K0_idx * K1 + K1_idx
        index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
        // Merge
        // YStep = K_idx / XDotSlice_K_
        index_t YStep =
            MagicDivision::DoMagicDivision(K_idx,
                                           this->low_lengths_magic_divisor_multiplier_[I0],
                                           this->low_lengths_magic_divisor_shift_[I0]);
        index_t KStep = K_idx - YStep * XDotSlice_K_;
        // Xstep = KStep / K_
        index_t XStep =
            MagicDivision::DoMagicDivision(KStep,
                                           this->low_lengths_magic_divisor_multiplier_[I1],
                                           this->low_lengths_magic_divisor_shift_[I1]);
        KStep -= XStep * K_;
        // Embed
        YStep *= HRatio_;
        XStep *= WRatio_;

        return make_tuple(0, YStep, XStep, KStep);
    }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
    }

    /// Recomputes the lower index from scratch (the upper-index diff is
    /// ignored) and reports the resulting lower-index difference.
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& /* idx_diff_up */,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up,
                                              Number<Hack>) const
    {
        LowIdx low_old = idx_low;
        idx_low        = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
        idx_diff_low   = idx_low - low_old;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    // NOTE(review): this returns true even though MPad_/KPad_ may be non-zero,
    // in which case IsValidUpperIndexMappedToValidLowerIndex below is not
    // trivially true. If the descriptor machinery skips the per-index check
    // whenever this is true, padded out-of-range elements would go unmasked —
    // confirm against how other padding transforms in this file report it.
    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    /// An upper index is valid iff it falls inside the unpadded M extent
    /// (MPadded - MPad) and the unpadded flattened-K extent (K0*K1 - KPad).
    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        // Padding
        index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
        // BUGFIX: was `index_t& M_idx = idx_up[...];` — a non-const lvalue
        // reference cannot bind to an element of a const upper index, so the
        // template failed to compile (or bound to a temporary) on
        // instantiation. Copy the value instead.
        const index_t M_idx = idx_up[Number<1>{}];

        bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
                         K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
        return pad_valid;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("ConvBwdDataImplicitGemmOutTransform, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};

template <typename LowerIndex>
struct Freeze
{
Expand Down
55 changes: 54 additions & 1 deletion include/ck/tensor_description/multi_index_transform_helper.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform(
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}

/// Builds a ConvBwdDataImplicitGemmOutTransform for the 2D backward-data
/// implicit GEMM, deriving the M/K right-padding from the block tile sizes
/// and the H/W embed ratios from the dilation / gcd(stride, dilation).
/// YDot is accepted for signature symmetry but not consumed.
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
                                                                    index_t Ho,
                                                                    index_t Wo,
                                                                    index_t K,
                                                                    [[maybe_unused]] index_t YDot,
                                                                    index_t XDot,
                                                                    index_t HTilde,
                                                                    index_t WTilde,
                                                                    index_t ConvDilationH,
                                                                    index_t ConvDilationW,
                                                                    index_t HTildeSlice,
                                                                    index_t WTildeSlice,
                                                                    index_t YDotSlice,
                                                                    index_t XDotSlice,
                                                                    index_t IHTildeSliceBegin,
                                                                    index_t IWTildeSliceBegin,
                                                                    index_t GcdStrideDilationH,
                                                                    index_t GcdStrideDilationW,
                                                                    index_t K0,
                                                                    index_t K1,
                                                                    index_t MPerBlock,
                                                                    index_t GemmKPerBlock)
{
    // M = N * HTildeSlice * WTildeSlice, rounded up to a multiple of the
    // block tile; the surplus becomes right-padding.
    const index_t m_raw    = N * HTildeSlice * WTildeSlice;
    const index_t m_padded = math::integer_divide_ceil(m_raw, MPerBlock) * MPerBlock;

    // Flattened K = YDotSlice * XDotSlice * K, padded the same way.
    const index_t k_raw    = YDotSlice * XDotSlice * K;
    const index_t k_padded = math::integer_divide_ceil(k_raw, GemmKPerBlock) * GemmKPerBlock;

    // Embed multipliers applied to the Y/X steps inside the transform.
    const index_t h_ratio = -ConvDilationH / GcdStrideDilationH;
    const index_t w_ratio = -ConvDilationW / GcdStrideDilationW;

    return ConvBwdDataImplicitGemmOutTransform{N,
                                               Ho,
                                               Wo,
                                               K,
                                               XDot,
                                               HTilde,
                                               WTilde,
                                               WTildeSlice,
                                               HTildeSlice * WTildeSlice,
                                               IHTildeSliceBegin,
                                               IWTildeSliceBegin,
                                               h_ratio,
                                               w_ratio,
                                               XDotSlice * K,
                                               K0,
                                               m_padded,
                                               K1,
                                               m_padded - m_raw,
                                               k_padded - k_raw};
}

template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
namespace ck {
namespace tensor_operation {

/**
* @brief Enable custom tensor transform for convolution backward data output.
*
* When set to 1, this macro enables a custom transformation of the output tensor
* in convolution backward data operations.
*/
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1

template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
Expand Down Expand Up @@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1

if constexpr(NDimSpatial == 2)
{
const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;

#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
Expand Down Expand Up @@ -762,21 +776,53 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(GemmKPerBlock, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});

const index_t K0PerBlock = GemmKPerBlock / AK1;
const index_t AK0 =
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
AK1 * K0PerBlock * batch_k_) *
K0PerBlock;

const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

return out_gemmak0_gemmm_gemmak1_grid_desc;
#else
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Ho_, I0, I0),
make_pad_transform(Wo_, I0, I0),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_conv_bwd_data_out_transform(N_,
Ho_,
Wo_,
K_,
YDot_,
XDot_,
HTilde_,
WTilde_,
ConvDilationH_,
ConvDilationW_,
HTildeSlice,
WTildeSlice,
YDotSlice,
XDotSlice,
IHTildeSliceBegin,
IWTildeSliceBegin,
GcdStrideDilationH_,
GcdStrideDilationW_,
AK0,
AK1,
GemmMPerBlock,
GemmKPerBlock)),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0, 1, 2>{}));

return out_n_hop_wop_k_grid_desc_final;
#endif
}
else if constexpr(NDimSpatial == 3)
{
Expand Down
Loading