
Commit 624c468

jakpiase and bartekxk authored
[CK_TILE] Add conv bwd weight two stage support (#2855)
* resolved conflicts
* add conv bwd weight twostage
* fix one file
* fixes after review
* fixes
* fixes
* Fix

---------

Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
1 parent 4363a82 commit 624c468

16 files changed: +864 -361 lines changed

example/ck_tile/20_grouped_convolution/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -7,5 +7,8 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP
 add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
 target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
 
+add_executable(tile_example_grouped_conv_bwd_weight_two_stage EXCLUDE_FROM_ALL grouped_convolution_backward_weight_two_stage.cpp)
+target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+
 add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
 target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp

Lines changed: 26 additions & 17 deletions
@@ -41,8 +41,8 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
     constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
     constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
 
-    constexpr ck_tile::index_t VectorSizeA = 8;
-    constexpr ck_tile::index_t VectorSizeB = 8;
+    constexpr ck_tile::index_t VectorSizeA = 1;
+    constexpr ck_tile::index_t VectorSizeB = 1;
     constexpr ck_tile::index_t VectorSizeC = 8;
 
     // Implicit GEMM Traits
@@ -51,20 +51,29 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
 
-    constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
-    using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
-    using GroupedConvTraitsType =
-        ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
-    using CodegenPipelineProblem =
-        ck_tile::GemmPipelineProblem<InDataType,
-                                     WeiDataType,
-                                     AccDataType,
-                                     CodegenShape,
-                                     typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
-                                     InDataType,
-                                     true,
-                                     VectorSizeA,
-                                     VectorSizeB>;
+    constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
+    using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
+    using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
+                                                             ConvSpec,
+                                                             InLayout,
+                                                             WeiLayout,
+                                                             DsLayout,
+                                                             OutLayout,
+                                                             VectorSizeA,
+                                                             VectorSizeB,
+                                                             VectorSizeC>;
+    using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<
+        InDataType,
+        WeiDataType,
+        AccDataType,
+        CodegenShape,
+        typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData,
+        ck_tile::element_wise::PassThrough,
+        ck_tile::element_wise::PassThrough,
+        InDataType,
+        true,
+        GroupedConvTraitsType::VectorSizeA,
+        GroupedConvTraitsType::VectorSizeB>;
     using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
 
     const auto Run = [&](const auto memory_operation_) {
@@ -90,7 +99,7 @@ float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
                                              memory_operation,
                                              1,
                                              true,
-                                             VectorSizeC>>;
+                                             GroupedConvTraitsType::VectorSizeC>>;
 
         using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
                                                                      TilePartitioner,
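
Note on the change above: the per-operand vectorization constants no longer feed the pipeline and epilogue directly; they are passed into ck_tile::GroupedConvTraits and read back as GroupedConvTraitsType::VectorSizeA/B/C. The following standalone sketch uses hypothetical type names (not the actual ck_tile definitions) to show the pattern of carrying the vector sizes on the traits type:

// Hedged sketch, not the real ck_tile code: vector sizes become template
// parameters of the traits type and are re-exported as static members, so
// downstream "problem" types read them from one place instead of taking
// loose VectorSizeA/B/C parameters.
#include <cstdint>

using index_t = std::int32_t;

template <index_t NDimSpatial_, index_t VecA, index_t VecB, index_t VecC>
struct GroupedConvTraitsSketch
{
    static constexpr index_t NDimSpatial = NDimSpatial_;
    static constexpr index_t VectorSizeA = VecA;
    static constexpr index_t VectorSizeB = VecB;
    static constexpr index_t VectorSizeC = VecC;
};

// A downstream problem only needs the traits type; the vector sizes travel with it.
template <typename Traits>
struct EpilogueProblemSketch
{
    static constexpr index_t VectorSizeC = Traits::VectorSizeC;
};

int main()
{
    // Mirrors the example above: A/B vectorization of 1, C vectorization of 8.
    using Traits = GroupedConvTraitsSketch</*NDimSpatial=*/2, /*VecA=*/1, /*VecB=*/1, /*VecC=*/8>;
    static_assert(EpilogueProblemSketch<Traits>::VectorSizeC == 8,
                  "vector sizes propagate through the traits type");
    return 0;
}

Keeping the sizes on the traits type means the pipeline problem and the epilogue read the same values and cannot drift apart.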

example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight.cpp

Lines changed: 26 additions & 191 deletions
@@ -11,195 +11,13 @@
 
 #include "ck_tile/host.hpp"
 #include "grouped_convolution_utils.hpp"
-
-template <ck_tile::index_t NDimSpatial,
-          typename GemmWarpConfig,
-          typename InDataType,
-          typename WeiDataType,
-          typename AccDataType,
-          typename OutDataType,
-          typename InLayout,
-          typename WeiLayout,
-          typename OutLayout,
-          typename DsDataType     = ck_tile::tuple<>,
-          typename DsLayout       = ck_tile::tuple<>,
-          typename CDEElementWise = ck_tile::element_wise::PassThrough>
-float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
-                              const ck_tile::stream_config& s)
-{
-    constexpr int kBlockPerCu = 1;
-
-    constexpr ck_tile::index_t M_Tile = 64;
-    constexpr ck_tile::index_t N_Tile = 64;
-    constexpr ck_tile::index_t K_Tile = 64;
-
-    constexpr ck_tile::index_t M_Warp = 2;
-    constexpr ck_tile::index_t N_Warp = 2;
-    constexpr ck_tile::index_t K_Warp = 1;
-
-    constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
-    constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
-    constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
-
-    constexpr ck_tile::index_t VectorSizeA = 8;
-    constexpr ck_tile::index_t VectorSizeB = 8;
-    constexpr ck_tile::index_t VectorSizeC = 8;
-
-    // Implicit GEMM Traits
-    using CodegenShape =
-        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
-                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
-                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-
-    constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
-    using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
-    using GroupedConvTraitsType =
-        ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
-    using CodegenPipelineProblem =
-        ck_tile::GemmPipelineProblem<InDataType,
-                                     WeiDataType,
-                                     AccDataType,
-                                     CodegenShape,
-                                     typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
-                                     InDataType,
-                                     true,
-                                     VectorSizeA,
-                                     VectorSizeB>;
-    using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
-
-    const auto Run = [&](const auto memory_operation_) {
-        constexpr auto memory_operation = memory_operation_.value;
-
-        using ConvEpilogue = ck_tile::CShuffleEpilogue<
-            ck_tile::CShuffleEpilogueProblem<InDataType,
-                                             WeiDataType,
-                                             DsDataType,
-                                             AccDataType,
-                                             OutDataType,
-                                             typename GroupedConvTraitsType::ImplicitGemmDsLayout,
-                                             ck_tile::tensor_layout::gemm::RowMajor,
-                                             CDEElementWise,
-                                             TilePartitioner::MPerBlock,
-                                             TilePartitioner::NPerBlock,
-                                             M_Warp,
-                                             N_Warp,
-                                             M_Warp_Tile,
-                                             N_Warp_Tile,
-                                             K_Warp_Tile,
-                                             CodegenPipelineProblem::TransposeC,
-                                             memory_operation,
-                                             1,
-                                             true,
-                                             VectorSizeC>>;
-
-        using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
-                                                                       TilePartitioner,
-                                                                       CodegenPipeline,
-                                                                       ConvEpilogue>;
-        auto kargs = Kernel::MakeKernelArgs(args);
-
-        const dim3 grids = Kernel::GridSize(kargs);
-        const dim3 blocks = Kernel::BlockSize();
-
-        if(!Kernel::IsSupportedArgument(kargs))
-        {
-            throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
-        }
-
-        if(s.log_level_ > 0)
-        {
-            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
-                      << "shape: " << CodegenShape::GetName() << '\n'
-                      << "problem: " << CodegenPipelineProblem::GetName() << '\n'
-                      << "pipeline: " << CodegenPipeline::GetName() << '\n'
-                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
-                      << '\n'
-                      << "Vector size A: " << CodegenPipeline::GetVectorSizeA()
-                      << ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
-                      << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
-        }
-
-        float ave_time = ck_tile::launch_kernel_time_mask(
-            s,
-            Kernel::Preprocess(kargs, s),
-            ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-
-        return ave_time;
-    };
-
-    if(args.k_batch == 1)
-    {
-        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
-                                              ck_tile::memory_operation_enum::set>{});
-    }
-    else
-    {
-        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
-                                              ck_tile::memory_operation_enum::atomic_add>{});
-    }
-}
-
+#include "grouped_convolution_backward_weight_invoker.hpp"
 #include "run_grouped_convolution_bwd_weight_example.inc"
 
-template <typename GemmWarpConfig,
-          typename InPrecType,
-          typename WeiPrecType = InPrecType,
-          typename OutPrecType = InPrecType>
-int run_grouped_conv_bwd_weight_example_prec_type(
-    std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
-{
-    using NWGC = ck_tile::tensor_layout::convolution::NWGC;
-    using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
-    using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
-
-    using GKXC = ck_tile::tensor_layout::convolution::GKXC;
-    using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
-    using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
-
-    using NWGK = ck_tile::tensor_layout::convolution::NWGK;
-    using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
-    using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
-
-    if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
-    {
-        return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
-                                                                GemmWarpConfig,
-                                                                InPrecType,
-                                                                WeiPrecType,
-                                                                OutPrecType>(
-            argc, argv, NWGC{}, GKXC{}, NWGK{});
-    }
-    else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
-    {
-        return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
-                                                                GemmWarpConfig,
-                                                                InPrecType,
-                                                                WeiPrecType,
-                                                                OutPrecType>(
-            argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
-    }
-    else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
-    {
-        return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
-                                                                GemmWarpConfig,
-                                                                InPrecType,
-                                                                WeiPrecType,
-                                                                OutPrecType>(
-            argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
-    }
-    else
-    {
-        throw std::runtime_error("Unsupported memory layout!");
-    }
-}
-
 template <typename GemmWarpConfig>
-int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
+int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
 {
-    auto [result, arg_parser] = create_args(argc, argv);
-    if(!result)
-        return -1;
+    using Invoker = GroupedConvolutionBackwardWeightInvoker;
 
     std::string data_type = arg_parser.get_str("prec");
     std::string in_layout = arg_parser.get_str("in_layout");
@@ -208,13 +26,17 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
 
     if(data_type == "fp16")
     {
-        return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
-            in_layout, wei_layout, out_layout, argc, argv);
+        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
+                                                             GemmWarpConfig,
+                                                             ck_tile::half_t>(
+            in_layout, wei_layout, out_layout, arg_parser);
     }
     else if(data_type == "bf16")
    {
-        return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
-            in_layout, wei_layout, out_layout, argc, argv);
+        return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
+                                                             GemmWarpConfig,
+                                                             ck_tile::bf16_t>(
+            in_layout, wei_layout, out_layout, arg_parser);
     }
     else
     {
@@ -224,9 +46,22 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
 
 int main(int argc, char* argv[])
 {
+
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    try
+    {
 #if CK_TILE_USE_WMMA
-    return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(argc, argv);
+        return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
 #else
-    return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(argc, argv);
+        return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
 #endif
+    }
+    catch(const std::runtime_error& e)
+    {
+        std::cerr << "Runtime error: " << e.what() << '\n';
+        return EXIT_FAILURE;
+    }
 }
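
Note on the restructuring above: the kernel-setup code is removed from this file in favour of the new GroupedConvolutionBackwardWeightInvoker (included from grouped_convolution_backward_weight_invoker.hpp), run_grouped_conv_bwd_weight_example now receives a ck_tile::ArgParser instead of argc/argv, and main() parses the arguments once and wraps the run in a try/catch so the runtime_error thrown for unsupported arguments becomes a clean exit code. A minimal sketch of that control flow, using stand-in types rather than the real ck_tile::ArgParser and invoker:

// Hedged sketch of the restructured control flow with hypothetical stand-ins.
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>

struct ArgParserSketch // stand-in for ck_tile::ArgParser
{
    std::string get_str(const std::string& /*name*/) const { return "fp16"; }
};

// Stand-in for run_grouped_conv_bwd_weight_example<GemmWarpConfig>(arg_parser):
// returns non-zero on success, throws on an unsupported configuration.
int run_example(ArgParserSketch& arg_parser)
{
    const std::string prec = arg_parser.get_str("prec");
    if(prec != "fp16" && prec != "bf16")
    {
        throw std::runtime_error("Unsupported data type!");
    }
    // ... select layouts / precision and launch the kernel through the invoker ...
    return 1;
}

int main()
{
    ArgParserSketch arg_parser; // in the real example this comes from create_args(argc, argv)

    try
    {
        // The examples negate the result, so a non-zero run result maps to exit code 0.
        return !run_example(arg_parser);
    }
    catch(const std::runtime_error& e)
    {
        std::cerr << "Runtime error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }
}

Parsing the arguments once in main() avoids re-running create_args inside the templated runner, and the single catch block turns the "argument not supported" path into an error message instead of an unhandled exception.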
