#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
-
- template <ck_tile::index_t NDimSpatial,
-           typename GemmWarpConfig,
-           typename InDataType,
-           typename WeiDataType,
-           typename AccDataType,
-           typename OutDataType,
-           typename InLayout,
-           typename WeiLayout,
-           typename OutLayout,
-           typename DsDataType     = ck_tile::tuple<>,
-           typename DsLayout       = ck_tile::tuple<>,
-           typename CDEElementWise = ck_tile::element_wise::PassThrough>
- float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
-                               const ck_tile::stream_config& s)
- {
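-     // Host-side launcher: lowers the grouped conv backward-weight problem to a
-     // CK Tile implicit GEMM, checks support, launches the kernel, and returns
-     // the averaged kernel time.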
-     constexpr int kBlockPerCu = 1;
-
-     constexpr ck_tile::index_t M_Tile = 64;
-     constexpr ck_tile::index_t N_Tile = 64;
-     constexpr ck_tile::index_t K_Tile = 64;
-
-     constexpr ck_tile::index_t M_Warp = 2;
-     constexpr ck_tile::index_t N_Warp = 2;
-     constexpr ck_tile::index_t K_Warp = 1;
-
-     constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile;
-     constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile;
-     constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile;
-
-     constexpr ck_tile::index_t VectorSizeA = 8;
-     constexpr ck_tile::index_t VectorSizeB = 8;
-     constexpr ck_tile::index_t VectorSizeC = 8;
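-     // Tile hierarchy: each 64x64x64 block tile is worked on by a 2x2x1 grid of
-     // warps, each computing the warp tile from GemmWarpConfig; A/B/C global
-     // memory accesses are vectorized by 8 elements.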
-
-     // Implicit GEMM Traits
-     using CodegenShape =
-         ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
-                                ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
-                                ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
-
-     constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
-     using TilePartitioner   = ck_tile::GemmTile1DPartitioner<CodegenShape>;
-     using GroupedConvTraitsType =
-         ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
-     using CodegenPipelineProblem =
-         ck_tile::GemmPipelineProblem<InDataType,
-                                      WeiDataType,
-                                      AccDataType,
-                                      CodegenShape,
-                                      typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
-                                      InDataType,
-                                      true,
-                                      VectorSizeA,
-                                      VectorSizeB>;
-     using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
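-     // Per the pipeline name, A and B tiles are streamed from global memory
-     // while the C accumulator stays in registers.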
-
-     const auto Run = [&](const auto memory_operation_) {
-         constexpr auto memory_operation = memory_operation_.value;
-
-         using ConvEpilogue = ck_tile::CShuffleEpilogue<
-             ck_tile::CShuffleEpilogueProblem<InDataType,
-                                              WeiDataType,
-                                              DsDataType,
-                                              AccDataType,
-                                              OutDataType,
-                                              typename GroupedConvTraitsType::ImplicitGemmDsLayout,
-                                              ck_tile::tensor_layout::gemm::RowMajor,
-                                              CDEElementWise,
-                                              TilePartitioner::MPerBlock,
-                                              TilePartitioner::NPerBlock,
-                                              M_Warp,
-                                              N_Warp,
-                                              M_Warp_Tile,
-                                              N_Warp_Tile,
-                                              K_Warp_Tile,
-                                              CodegenPipelineProblem::TransposeC,
-                                              memory_operation,
-                                              1,
-                                              true,
-                                              VectorSizeC>>;
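-         // Instantiating the epilogue per memory operation lets one launcher
-         // either set the output or atomically accumulate into it (needed for
-         // the split-K dispatch below).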
-
-         using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
-                                                                        TilePartitioner,
-                                                                        CodegenPipeline,
-                                                                        ConvEpilogue>;
-         auto kargs = Kernel::MakeKernelArgs(args);
-
-         const dim3 grids  = Kernel::GridSize(kargs);
-         const dim3 blocks = Kernel::BlockSize();
-
-         if(!Kernel::IsSupportedArgument(kargs))
-         {
-             throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
-         }
-
-         if(s.log_level_ > 0)
-         {
-             std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
-                       << "shape: " << CodegenShape::GetName() << '\n'
-                       << "problem: " << CodegenPipelineProblem::GetName() << '\n'
-                       << "pipeline: " << CodegenPipeline::GetName() << '\n'
-                       << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                       << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
-                       << '\n'
-                       << "Vector size A: " << CodegenPipeline::GetVectorSizeA()
-                       << ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
-                       << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
-         }
-
-         float ave_time = ck_tile::launch_kernel_time_mask(
-             s,
-             Kernel::Preprocess(kargs, s),
-             ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-
-         return ave_time;
-     };
-
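-     // k_batch == 1 writes each weight exactly once, so a plain set suffices;
-     // with split-K (k_batch > 1) different K-batches write to the same weight
-     // tensor, so partial results are combined with atomic adds.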
-     if(args.k_batch == 1)
-     {
-         return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
-                                               ck_tile::memory_operation_enum::set>{});
-     }
-     else
-     {
-         return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
-                                               ck_tile::memory_operation_enum::atomic_add>{});
-     }
- }
-
+ #include "grouped_convolution_backward_weight_invoker.hpp"

#include "run_grouped_convolution_bwd_weight_example.inc"

- template <typename GemmWarpConfig,
-           typename InPrecType,
-           typename WeiPrecType = InPrecType,
-           typename OutPrecType = InPrecType>
- int run_grouped_conv_bwd_weight_example_prec_type(
-     std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
- {
-     using NWGC   = ck_tile::tensor_layout::convolution::NWGC;
-     using NHWGC  = ck_tile::tensor_layout::convolution::NHWGC;
-     using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
-
-     using GKXC   = ck_tile::tensor_layout::convolution::GKXC;
-     using GKYXC  = ck_tile::tensor_layout::convolution::GKYXC;
-     using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
-
-     using NWGK   = ck_tile::tensor_layout::convolution::NWGK;
-     using NHWGK  = ck_tile::tensor_layout::convolution::NHWGK;
-     using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
-
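-     // Map the runtime layout strings onto compile-time layout tags and the
-     // matching spatial dimensionality (1D, 2D, or 3D).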
-     if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
-     {
-         return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
-                                                                 GemmWarpConfig,
-                                                                 InPrecType,
-                                                                 WeiPrecType,
-                                                                 OutPrecType>(
-             argc, argv, NWGC{}, GKXC{}, NWGK{});
-     }
-     else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
-     {
-         return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
-                                                                 GemmWarpConfig,
-                                                                 InPrecType,
-                                                                 WeiPrecType,
-                                                                 OutPrecType>(
-             argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
-     }
-     else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
-     {
-         return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
-                                                                 GemmWarpConfig,
-                                                                 InPrecType,
-                                                                 WeiPrecType,
-                                                                 OutPrecType>(
-             argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
-     }
-     else
-     {
-         throw std::runtime_error("Unsupported memory layout!");
-     }
- }
-
template <typename GemmWarpConfig>
- int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
+ int run_grouped_conv_bwd_weight_example(ck_tile::ArgParser& arg_parser)
{
-     auto [result, arg_parser] = create_args(argc, argv);
-     if(!result)
-         return -1;
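+     // The tile-shape, pipeline, and epilogue plumbing removed above now sits
+     // behind the invoker from grouped_convolution_backward_weight_invoker.hpp.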
+     using Invoker = GroupedConvolutionBackwardWeightInvoker;

    std::string data_type = arg_parser.get_str("prec");
    std::string in_layout = arg_parser.get_str("in_layout");
@@ -208,13 +26,17 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])

    if(data_type == "fp16")
    {
-         return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::half_t>(
-             in_layout, wei_layout, out_layout, argc, argv);
+         return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
+                                                              GemmWarpConfig,
+                                                              ck_tile::half_t>(
+             in_layout, wei_layout, out_layout, arg_parser);
    }
    else if(data_type == "bf16")
    {
-         return run_grouped_conv_bwd_weight_example_prec_type<GemmWarpConfig, ck_tile::bf16_t>(
-             in_layout, wei_layout, out_layout, argc, argv);
+         return run_grouped_conv_bwd_weight_example_prec_type<Invoker,
+                                                              GemmWarpConfig,
+                                                              ck_tile::bf16_t>(
+             in_layout, wei_layout, out_layout, arg_parser);
    }
    else
    {
@@ -224,9 +46,22 @@ int run_grouped_conv_bwd_weight_example(int argc, char* argv[])

int main(int argc, char* argv[])
{
+
+     auto [result, arg_parser] = create_args(argc, argv);
+     if(!result)
+         return -1;
+
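+     // Parse arguments once in main; failures thrown by the example below
+     // (e.g. unsupported layouts) are routed through the catch block.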
+     try
+     {
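+         // Pick the warp-level GEMM instruction set at compile time: WMMA on
+         // WMMA-capable (RDNA-class) GPUs, MFMA on CDNA-class GPUs.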
#if CK_TILE_USE_WMMA
-     return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(argc, argv);
+         return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Wmma>(arg_parser);
#else
-     return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(argc, argv);
+         return !run_grouped_conv_bwd_weight_example<GemmWarpConfig_Mfma>(arg_parser);
#endif
+     }
+     catch(const std::runtime_error& e)
+     {
+         std::cerr << "Runtime error: " << e.what() << '\n';
+         return EXIT_FAILURE;
+     }
}