From 58e82b813b2109aca56185a7841189c6645d147c Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 16 Sep 2025 10:00:41 +0800 Subject: [PATCH 01/19] conv:tf32:add more instances --- example/01_gemm/CMakeLists.txt | 10 + .../01_gemm/gemm_xdl_lds_direct_load_fp32.cpp | 2 +- .../gemm_xdl_lds_direct_load_fp32_tf32.cpp | 4 +- example/09_convnd_fwd/CMakeLists.txt | 9 + example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp | 22 +- .../convnd_fwd_xdl_fp32_tf32.cpp | 22 +- example/17_convnd_bwd_data/CMakeLists.txt | 17 ++ .../convnd_bwd_data_common.hpp | 9 +- .../convnd_bwd_data_xdl_fp32.cpp | 207 +++++++++++++++++ .../convnd_bwd_data_xdl_fp32_tf32.cpp | 212 ++++++++++++++++++ include/ck/library/utility/check_err.hpp | 53 ++++- .../gpu/block/blockwise_gemm_xdlops.hpp | 30 ++- ...device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp | 48 +++- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 23 +- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 45 ++-- include/ck/utility/amd_xdlops.hpp | 4 +- .../cpu/reference_conv_bwd_data.hpp | 22 +- ...grouped_conv_fwd_xdl_bilinear_instance.hpp | 35 +++ .../device_grouped_conv_fwd_xdl_instance.hpp | 19 ++ ...ed_conv_fwd_xdl_merged_groups_instance.hpp | 22 ++ ...ce_grouped_conv_fwd_xdl_scale_instance.hpp | 35 +++ ...uped_conv_fwd_xdl_scaleadd_ab_instance.hpp | 24 +- ...wd_xdl_scaleadd_scaleadd_relu_instance.hpp | 23 ++ .../gpu/grouped_convolution_forward.hpp | 175 ++++++++++----- ...grouped_convolution_forward_bias_clamp.hpp | 90 ++++---- ...ped_convolution_forward_bias_clamp_xdl.inc | 48 ++++ .../grouped_convolution_forward_bilinear.hpp | 28 ++- .../gpu/grouped_convolution_forward_clamp.hpp | 89 ++++---- .../grouped_convolution_forward_clamp_xdl.inc | 48 ++++ .../gpu/grouped_convolution_forward_scale.hpp | 27 ++- ...rouped_convolution_forward_scaleadd_ab.hpp | 27 ++- ...olution_forward_scaleadd_scaleadd_relu.hpp | 28 ++- .../gpu/grouped_convolution_forward_xdl.inc | 81 ++++++- ..._convolution_forward_xdl_merged_groups.inc | 63 ++++++ .../gpu/grouped_conv1d_fwd/CMakeLists.txt | 1 + ...d_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp | 56 +++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 6 + ...dl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp | 66 ++++++ ...dl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 56 +++++ ...dl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp | 41 ++++ ...dl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 66 ++++++ ...ps_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 50 +++++ ...ps_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 50 +++++ .../CMakeLists.txt | 59 ++--- ...ups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in | 68 ++++++ .../CMakeLists.txt | 2 + ...l_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 62 +++++ ...s_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 56 +++++ .../grouped_conv2d_fwd_clamp/CMakeLists.txt | 2 + ...l_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 62 +++++ ...s_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 55 +++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 2 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 49 ++++ ...ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp | 49 ++++ .../CMakeLists.txt | 9 + ..._ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in | 68 ++++++ .../CMakeLists.txt | 1 + ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 53 +++++ .../CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 57 +++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 1 + ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 53 +++++ .../grouped_conv3d_fwd_scale/CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 57 +++++ .../CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 54 +++++ .../CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 60 +++++ 68 files changed, 2619 insertions(+), 257 deletions(-) create mode 100644 example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp create mode 100644 example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 03bde864214..d24362cc171 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -115,6 +115,16 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() +list(APPEND gpu_list gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32) + set(target 1) + endif() +endforeach() + add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp index 75971bdecf3..3cff8b30e2e 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp @@ -37,7 +37,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on #else // clang-format off diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp index 9b92fad779b..9b2c2df09e5 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -43,9 +43,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, - 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, - 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 32, 1, 4>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; // clang-format on #else // clang-format off diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 4f174bfcbb2..67766011c6b 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -29,3 +29,12 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() + +list(APPEND gpu_tf32_list gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_tf32_list AND target EQUAL 0) + add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index 40c38b39d87..d4e2a8e0b71 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -46,32 +46,32 @@ using DeviceGroupedConvNDFwdInstance = GemmSpec, // GemmSpecialization 1, // 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 16, // KPerBlock - 4, // AK1 - 4, // BK1 - 32, // MPerXdl - 32, // NPerXdl + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerXdl + 16, // NPerXdl 2, // MXdlPerWave - 4, // NXdlPerWave + 2, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 4, // ABlockTransferSrcScalarPerVector - 4, // ABlockTransferDstScalarPerVector_AK1 + 8, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 4, // BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_BK1 + 8, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockLdsExtraN 1, 1, - S<1, 16, 1, 16>, + S<1, 32, 1, 4>, 4>; #include "run_convnd_fwd_example.inc" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp index 348da7e1ef4..9264aee24d6 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp @@ -49,32 +49,32 @@ using DeviceGroupedConvNDFwdInstance = GemmSpec, // GemmSpecialization 1, // NumGemmKPrefetchStage 256, // BlockSize - 128, // MPerBlock - 192, // NPerBlock - 16, // KPerBlock - 4, // AK1 - 4, // BK1 - 32, // MPerXdl - 32, // NPerXdl + 64, // MPerBlock + 64, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 16, // MPerXdl + 16, // NPerXdl 2, // MXdlPerWave - 3, // NXdlPerWave + 2, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 4, // ABlockTransferSrcScalarPerVector - 4, // ABlockTransferDstScalarPerVector_AK1 + 8, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 4, // BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_BK1 + 8, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<1, 32, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 4, // CDEBlockTransferScalarPerVector_NPerBlock ComputeDataType, // AComputeDataType ComputeDataType, // BComputeDataType diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 39f9fb8ec06..70228d08938 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -3,6 +3,23 @@ if(result EQUAL 0) target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) endif() +add_example_executable(example_convnd_bwd_data_xdl_fp32 convnd_bwd_data_xdl_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp32 PRIVATE utility) +endif() + +list(APPEND gpu_list gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_convnd_bwd_data_xdl_fp32_tf32 convnd_bwd_data_xdl_fp32_tf32.cpp) + if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp32_tf32 PRIVATE utility) + endif() + set(target 1) + endif() +endforeach() + add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) if(result EQUAL 0) target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index d219df02453..aead9734901 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -33,7 +33,8 @@ template + typename DeviceConvNdBwdDataInstance, + typename ComputeDataType = OutDataType> int run_conv_bwd_data(bool do_verification, int init_method, bool time_kernel, @@ -150,7 +151,11 @@ int run_conv_bwd_data(bool do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>(); + OutElementOp, + 0, + 0, + 0, + ComputeDataType>(); auto ref_invoker = ref_conv.MakeInvoker(); diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp new file mode 100644 index 00000000000..c4037842a3a --- /dev/null +++ b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_bwd_data_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +template +using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl< + NDimSpatial, // NDimSpatial + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, + 1>; // GemmCThreadTransferDstScalarPerVector + +int main(int argc, char* argv[]) +{ + namespace ctc = ck::tensor_layout::convolution; + + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(conv_param.num_dim_spatial_ == 1) + { + using InLayout = ctc::GNWC; + using WeiLayout = ctc::GKXC; + using OutLayout = ctc::GNWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<1, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<1>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 2) + { + using InLayout = ctc::GNHWC; + using WeiLayout = ctc::GKYXC; + using OutLayout = ctc::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<2, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<2>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 3) + { + using InLayout = ctc::GNDHWC; + using WeiLayout = ctc::GKZYXC; + using OutLayout = ctc::GNDHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<3, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<3>>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + + return 0; +} diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp new file mode 100644 index 00000000000..b4a0a2273a9 --- /dev/null +++ b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_bwd_data_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +using AccDataType = float; +using ComputeDataType = ck::tf32_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +template +using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl< + NDimSpatial, // NDimSpatial + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdDefault, // ConvolutionBackwardDataSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1, // GemmCThreadTransferDstScalarPerVector + ComputeDataType>; + +int main(int argc, char* argv[]) +{ + namespace ctc = ck::tensor_layout::convolution; + + print_helper_msg(); + + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + ck::utils::conv::ConvParam conv_param{ + 2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + if(conv_param.num_dim_spatial_ == 1) + { + using InLayout = ctc::GNWC; + using WeiLayout = ctc::GKXC; + using OutLayout = ctc::GNWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<1, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<1>, + ComputeDataType>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 2) + { + using InLayout = ctc::GNHWC; + using WeiLayout = ctc::GKYXC; + using OutLayout = ctc::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<2, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<2>, + ComputeDataType>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + else if(conv_param.num_dim_spatial_ == 3) + { + using InLayout = ctc::GNDHWC; + using WeiLayout = ctc::GKZYXC; + using OutLayout = ctc::GNDHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_conv_bwd_data<3, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DeviceConvNdBwdDataInstance<3>, + ComputeDataType>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + } + + return 0; +} diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 185166f7ec3..ade4d9a5b4a 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -149,11 +149,62 @@ double get_absolute_threshold(const double max_possible_num, const int number_of return std::max(acc_error, midway_error); } +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_same_v, float>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-5) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = *std::next(std::begin(out), i); + const double r = *std::next(std::begin(ref), i); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + err_count++; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + template typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_floating_point_v> && - !std::is_same_v, half_t>, + !std::is_same_v, half_t> && + !std::is_same_v, float>, bool>::type check_err(const Range& out, const RefRange& ref, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 55015dd30f7..9fdd12adfbb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -69,7 +69,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr auto xdlops_gemm = - XdlopsGemm{}; + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -637,21 +637,19 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() } else if constexpr(LoopSched == LoopScheduler::Interwave) { - return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< - BlockSize, - FloatA, - FloatB, - FloatAcc, - AK0MK1BlockDesc, - BK0NK1BlockDesc, - MPerXDL, - NPerXDL, - MRepeat, - NRepeat, - KPack, - ComputeTypeA, - ComputeTypeB, - CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{}; + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index d0743421272..403a7be9689 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -55,7 +55,8 @@ template + ck::index_t CThreadTransferDstScalarPerVector, + typename ComputeDataType = InDataType> struct DeviceConvNdBwdDataNwcKxcNwk_Xdl : public DeviceConvBwdData< NDimSpatial, @@ -78,6 +79,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl WeiElementwiseOperation, OutElementwiseOperation> { + + DeviceConvNdBwdDataNwcKxcNwk_Xdl() + { + static_assert(is_same_v || + (is_same_v && is_same_v), + "InDataType and ComputeDataType need to be the same or (InDataType=float and " + "ComputeDataType=tf32_t)"); + } using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Xdl; GET_NXDL_PER_WAVE_IMPL @@ -89,7 +98,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl using CDataType = InDataType; // TODO make A/B datatype different - using ABDataType = InDataType; + using ABDataType = ComputeDataType; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -1195,6 +1204,36 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl } } + void Print() const + { + std::cout << "InDataType: " << get_type_name() + << "; WeiDataType: " << get_type_name() + << "; OutDataType: " << get_type_name() + << "; AccDataType: " << get_type_name() << std::endl; + auto print_v = [](std::ostream& os, + const std::vector& v, + const std::string& name) -> std::ostream& { + os << name << ": ["; + for(size_t i = 0; i < v.size(); ++i) + { + os << v[i]; + if(i + 1 < v.size()) + os << ", "; + } + os << "]"; + return os; + }; + std::cout << "Conv params: Ndims: " << NDimSpatial << ", N: " << Conv_N_ + << ", K: " << Conv_K_ << ", C: " << Conv_C_ << "\n\t"; + print_v(std::cout, input_spatial_lengths_, "input_spatial_lengths") << "\n\t"; + print_v(std::cout, filter_spatial_lengths_, "filter_spatial_lengths") << "\n\t"; + print_v(std::cout, output_spatial_lengths_, "output_spatial_lengths") << "\n\t"; + print_v(std::cout, conv_filter_strides_, "conv_filter_strides") << "\n\t"; + print_v(std::cout, conv_filter_dilations_, "conv_filter_dilations") << "\n\t"; + print_v(std::cout, input_left_pads_, "input_left_pads") << "\n\t"; + print_v(std::cout, input_right_pads_, "input_right_pads") << std::endl; + } + const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; @@ -1226,6 +1265,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl template float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index cbad6a56739..60dfead1f22 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -107,8 +107,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle using BComputeDataType = conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; - using BComputeDataType = BComputeDataType_; + using AComputeDataType = + conditional_t, float, AComputeDataType_>; + using BComputeDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -679,20 +681,21 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - (((is_same::value || - is_same::value) && + (((is_same::value || + is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8) || - ((is_same::value || is_same::value) && + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || + is_same::value) && lcm_AK1_BK1 < 32)) ? true : false; static constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); @@ -709,8 +712,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle NXdlPerWave, KPack, LoopSched, - AComputeDataType, - BComputeDataType>(); + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index aa7ce1f5b6a..d2418c09133 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -164,6 +164,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using ThisThreadBlock = ThisThreadBlock; + using ElementDataTypeAB = conditional_t, float, FloatAB>; + __host__ static auto CalculateGridSize(index_t M, index_t N) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); @@ -236,8 +238,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // Argument struct Argument : public Problem, public tensor_operation::device::BaseArgument { - __host__ Argument(const FloatAB* p_a_grid_, - const FloatAB* p_b_grid_, + __host__ Argument(const ElementDataTypeAB* p_a_grid_, + const ElementDataTypeAB* p_b_grid_, FloatC* p_c_grid_, index_t M_, index_t N_, @@ -252,8 +254,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { } - const FloatAB* p_a_grid; - const FloatAB* p_b_grid; + const ElementDataTypeAB* p_a_grid; + const ElementDataTypeAB* p_b_grid; FloatC* p_c_grid; }; @@ -329,7 +331,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); + return (a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ElementDataTypeAB); } template < @@ -450,8 +453,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + K1, + FloatABAdjusted, + FloatABAdjusted>; return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); } @@ -471,8 +476,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N> - __device__ static void Run(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, + __device__ static void Run(const ElementDataTypeAB* p_a_grid, + const ElementDataTypeAB* p_b_grid, FloatC* p_c_grid, void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, @@ -533,8 +538,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 Sequence, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + ElementDataTypeAB, + ElementDataTypeAB, decltype(a_grid_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1), ABlockTransferSrcAccessOrder, @@ -564,8 +569,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 Sequence, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + ElementDataTypeAB, + ElementDataTypeAB, decltype(b_grid_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1), BBlockTransferSrcAccessOrder, @@ -595,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // sanity check auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - FloatABAdjusted, - FloatABAdjusted, + ElementDataTypeAB, + ElementDataTypeAB, FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), @@ -605,7 +610,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 MXdlPerWave, NXdlPerWave, K1, - LoopSched>(); + LoopSched, + FloatABAdjusted, + FloatABAdjusted>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -614,10 +621,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0_n_k1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index be3a5cea423..7ff8e6b057a 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1647,8 +1647,8 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16> __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { #if defined(__gfx94__) - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else ignore = reg_a; ignore = reg_b; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 10b169c21e6..54f190b3ec6 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -28,6 +28,7 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvBwdData : public device::BaseOperator { @@ -142,8 +143,10 @@ struct ReferenceConvBwdData : public device::BaseOperator c, x); - v_acc += ck::type_convert(v_out) * - ck::type_convert(v_wei); + v_acc += ck::type_convert( + ck::type_convert(v_out)) * + ck::type_convert( + ck::type_convert(v_wei)); } } } @@ -235,8 +238,11 @@ struct ReferenceConvBwdData : public device::BaseOperator y, x); - v_acc += ck::type_convert(v_out) * - ck::type_convert(v_wei); + v_acc += + ck::type_convert( + ck::type_convert(v_out)) * + ck::type_convert( + ck::type_convert(v_wei)); } } } @@ -354,8 +360,12 @@ struct ReferenceConvBwdData : public device::BaseOperator x); v_acc += - ck::type_convert(v_out) * - ck::type_convert(v_wei); + ck::type_convert( + ck::type_convert< + ComputeDataType>(v_out)) * + ck::type_convert( + ck::type_convert< + ComputeDataType>(v_wei)); } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp index 1c3bfef8cec..416e64b5347 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp @@ -16,6 +16,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32> + // clang-format on + >; + template ; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -142,6 +143,27 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 32> + // clang-format on + >; + template using S = ck::Sequence; @@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_scale_f32_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_scale_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -89,7 +90,7 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, - // instances for small conv.K and conv.C + // instances for small conv.K and conv.C DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4>, @@ -97,6 +98,27 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -103,6 +104,28 @@ using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances = std::tu // clang-format on >; +template +using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> + // clang-format on + >; + template && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances(op_ptrs); + } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && @@ -169,10 +175,18 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + if constexpr(is_same_v && is_same_v) + { + + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances(op_ptrs); + } + else + { + + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -200,20 +214,29 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( - op_ptrs); + is_same_v) + { + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -285,18 +308,27 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances( - op_ptrs); + is_same_v) + { + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances(op_ptrs); + } + else + { + add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -352,6 +384,12 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances(op_ptrs); + } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && @@ -428,27 +466,33 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs); + is_same_v) + { + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); + } } + #endif #ifdef CK_ENABLE_FP8 @@ -546,18 +590,27 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( - op_ptrs); + is_same_v) + { + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index 11e827878c1..2b5ac46ca34 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -127,24 +127,34 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); + } } + #endif } // layout NDHWGC/GKZYXC/NDHWGK @@ -197,32 +207,34 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); + } } - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - op_ptrs); - } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc index 045d1623cf8..db743280d91 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc @@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instance PassThrough, AddClamp>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + Bilinear, + TF32, + TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -137,8 +153,16 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index c4fbbf1d903..0a382d19920 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -125,23 +125,34 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); + } } + #endif } // layout NDHWGC/GKZYXC/NDHWGK @@ -193,30 +204,32 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - op_ptrs); - } - - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - op_ptrs); + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); + } } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc index b0061b966d0..59eebb45a39 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc @@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( PassThrough, Clamp>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale, + TF32, + TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -137,7 +153,16 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index 1bea403afa2..2dfaa7eb2b8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -68,6 +68,21 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_inst ScaleAdd, ScaleAdd, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F32, + ScaleAdd, + ScaleAdd, + PassThrough, + TF32, + TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -137,8 +152,16 @@ struct DeviceOperationInstanceFactory> && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp index efb62664266..13894cac919 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp @@ -68,6 +68,22 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw PassThrough, PassThrough, ScaleAddScaleAddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + ScaleAddScaleAddRelu, + TF32, + TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -138,8 +154,16 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index af6041bbc5a..72f9591915c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -55,6 +55,21 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -120,6 +135,21 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances( + std::vector>>& instances); #endif // grouped conv2d forward, NHWGC/GKYXC/NHWGK @@ -211,6 +241,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -276,6 +322,21 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -383,6 +444,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -623,7 +700,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( BF8>>>& instances); #endif -#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +#if (defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( std::vector>>& instances); #endif -#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +#if (defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -176,6 +208,22 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_in PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW @@ -225,6 +273,21 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_in PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index ca4ea515bb0..6bb7e202eb1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -3,5 +3,6 @@ add_instance_library(device_grouped_conv1d_fwd_instance xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp + xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..0078d8788c3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<1, + GNWC, + GKXC, + Empty_Tuple, + GNWK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 7f3621a2ba7..92bdfd88a50 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -5,10 +5,12 @@ set(GROUPED_CONV2D_FWD xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp # NHWGC, GKYXC, NHWGK xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp @@ -17,9 +19,11 @@ set(GROUPED_CONV2D_FWD xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp # NGCHW, GKCYX, NGKHW xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp @@ -34,11 +38,13 @@ set(GROUPED_CONV2D_FWD xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp # NGCHW, GKCYX, NGKHW xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp #mem # NHWGC, GKYXC, NHWGK xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..9c8589c7b3f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..6f921c24322 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..451d5823996 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_generic_instances<2, + NGCHW, + GKYXC, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..8143553d543 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..8af95f920cd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..5f3f2a22478 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index c06e4f59538..61f7f4421ee 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -2,7 +2,7 @@ set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances @@ -11,7 +11,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances @@ -20,7 +20,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances @@ -29,7 +29,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances @@ -38,7 +38,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances @@ -47,7 +47,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances @@ -58,7 +58,7 @@ generate_sharded_instantiations( ) # large tensor # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances @@ -67,7 +67,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances @@ -76,7 +76,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances @@ -87,7 +87,7 @@ generate_sharded_instantiations( ) # merged groups # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances @@ -96,7 +96,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances @@ -105,7 +105,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances @@ -114,9 +114,18 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in + NUM_SHARDS 3 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups +) #mem # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances @@ -125,7 +134,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances @@ -134,7 +143,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances @@ -144,7 +153,7 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances @@ -153,7 +162,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances @@ -162,7 +171,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances @@ -173,7 +182,7 @@ generate_sharded_instantiations( ) #comp # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances @@ -182,7 +191,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances @@ -191,7 +200,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances @@ -200,7 +209,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances @@ -209,7 +218,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances @@ -218,7 +227,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances @@ -227,7 +236,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in new file mode 100644 index 00000000000..3d147035db8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index e63ac766b68..c370e613332 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -21,9 +21,11 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp16_comp_part2_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..0bf7f8b7b97 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..77905b3f67b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index 8faed08c050..5183c66a0af 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -21,9 +21,11 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp16_comp_part2_instance.cpp xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..a4b16917bb6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..b004b4f3cfc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd3x3, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index d0ae0ad42e0..3678c32e6c9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -24,9 +24,11 @@ set(GROUPED_CONV3D_FWD xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..d4a05792d74 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..753d452990f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index 6a776b49438..99a4d05d10c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -123,6 +123,15 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in + NUM_SHARDS 3 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups +) #mem # NDHWGC, GKZYXC, NDHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in new file mode 100644 index 00000000000..a857b7de4f0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd3x3, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index bcc7020ca9c..aef445a62e4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -20,6 +20,7 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..2988b715e0d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd3x3, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index 436c37fd58c..6a4637d6e18 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_bilinear_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..869c812b50d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + Bilinear, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 059d22f8d24..260deba2de9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -20,6 +20,7 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..66874c5696a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd3x3, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index f36d55d367e..47fc2655bb4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..5377cc56bd6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index 10762494474..74d4a3829aa 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_scaleadd_ab_instance ${GROUPED_CONV3D_FWD_SCALEADD_AB}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..315aefb8251 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + F32, + ScaleAdd, + ScaleAdd, + PassThrough, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt index 1be1db7d1d9..ea9bbc3a4ab 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt @@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance ${GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..35d86e0e9dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + ScaleAddScaleAddRelu, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + ck::Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + ck::Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + ck::Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 45d0057390e8775b62101fdff20c4c631883d523 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 18 Sep 2025 15:21:28 +0800 Subject: [PATCH 02/19] add instances of device_grouped_conv_fwd_xdl_f32_comp_instances --- example/01_gemm/CMakeLists.txt | 10 -- .../01_gemm/gemm_xdl_lds_direct_load_fp32.cpp | 2 +- .../gemm_xdl_lds_direct_load_fp32_tf32.cpp | 4 +- example/09_convnd_fwd/CMakeLists.txt | 9 -- .../blockwise_gemm_pipeline_xdlops_base.hpp | 7 +- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 56 ++++++----- ...kwise_gemm_pipeline_xdlops_v1_ab_scale.hpp | 60 ++++++------ ...ckwise_gemm_pipeline_xdlops_v1_b_scale.hpp | 28 +++--- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 96 +++++++++--------- ...kwise_gemm_pipeline_xdlops_v2_ab_scale.hpp | 46 ++++----- ...ckwise_gemm_pipeline_xdlops_v2_b_scale.hpp | 98 ++++++++++--------- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 34 ++++--- ...kwise_gemm_pipeline_xdlops_v3_ab_scale.hpp | 34 ++++--- ...ckwise_gemm_pipeline_xdlops_v3_b_scale.hpp | 36 +++---- .../blockwise_gemm_pipeline_xdlops_v4.hpp | 46 ++++----- ...ckwise_gemm_pipeline_xdlops_v4_b_scale.hpp | 48 ++++----- .../blockwise_gemm_pipeline_xdlops_v5.hpp | 46 ++++----- ...ice_grouped_conv_fwd_xdl_comp_instance.hpp | 22 +++++ .../gpu/grouped_convolution_forward.hpp | 8 ++ ...grouped_convolution_forward_bias_clamp.hpp | 4 + ...ped_convolution_forward_bias_clamp_xdl.inc | 32 ++++++ .../gpu/grouped_convolution_forward_clamp.hpp | 4 + .../grouped_convolution_forward_clamp_xdl.inc | 32 ++++++ .../grouped_convolution_forward_comp_xdl.inc | 61 ++++++++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 1 + ...chw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp | 41 ++++++++ ...wgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp | 68 +++++++++++++ .../CMakeLists.txt | 9 ++ ...hwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in | 82 ++++++++++++++++ .../CMakeLists.txt | 1 + ...gc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp | 65 ++++++++++++ .../grouped_conv2d_fwd_clamp/CMakeLists.txt | 1 + ...gc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp | 65 ++++++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 6 +- ...c_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp | 57 +++++++++++ ...w_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp | 57 +++++++++++ .../CMakeLists.txt | 9 ++ ...gc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in | 82 ++++++++++++++++ .../CMakeLists.txt | 1 + ..._gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp | 64 ++++++++++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 1 + ..._gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp | 63 ++++++++++++ 42 files changed, 1173 insertions(+), 323 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index d24362cc171..03bde864214 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -115,16 +115,6 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp) - add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32) - set(target 1) - endif() -endforeach() - add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp index 3cff8b30e2e..75971bdecf3 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp @@ -37,7 +37,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 32, 1, 4>, 4>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on #else // clang-format off diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp index 9b2c2df09e5..9b92fad779b 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -43,7 +43,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, S<4, 64, 1>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 32, 1, 4>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, + 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, + 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; // clang-format on #else // clang-format off diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 67766011c6b..4f174bfcbb2 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -29,12 +29,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() - -list(APPEND gpu_tf32_list gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_tf32_list AND target EQUAL 0) - add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp) - set(target 1) - endif() -endforeach() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index ff64b6fe2a5..d664a822aa4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -54,6 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr auto xdlops_gemm = XdlopsGemm{}; + using ComputeDataTypeBuf = + conditional_t::value, float, ComputeDataType>; + static constexpr index_t AMmaKStride = KPack; static constexpr index_t BMmaKStride = KPack; @@ -376,7 +379,7 @@ struct BlockwiseGemmXdlops_pipeline_base make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -386,7 +389,7 @@ struct BlockwiseGemmXdlops_pipeline_base A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index f597573dc2a..f281184c144 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -240,20 +242,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -301,20 +303,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -439,6 +441,8 @@ struct BlockwiseGemmXdlops_pipeline_v1( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -551,20 +555,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -640,20 +644,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -704,7 +708,7 @@ struct BlockwiseGemmXdlops_pipeline_v1, @@ -714,7 +718,7 @@ struct BlockwiseGemmXdlops_pipeline_v1; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index ea4f5e4a286..1af982e1652 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / + // sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / + // sizeof(ADataType) : sizeof(ComputeDataTypeBuf) + // / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -351,9 +355,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); @@ -516,17 +520,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -646,17 +650,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -737,17 +741,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -791,17 +795,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -842,7 +846,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale, @@ -852,7 +856,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp index 4246f4a44e7..123174e090d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp @@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 @@ -279,20 +281,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -360,20 +362,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 4cc1cf569d6..b474ddf5286 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( @@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -284,20 +286,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -355,20 +357,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -410,20 +412,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -461,20 +463,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -628,6 +630,8 @@ struct BlockwiseGemmXdlops_pipeline_v2( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -786,20 +790,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -885,20 +889,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -961,20 +965,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1037,20 +1041,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1129,7 +1133,7 @@ struct BlockwiseGemmXdlops_pipeline_v2, @@ -1139,7 +1143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index 119f8a33060..70f31246f29 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( @@ -257,9 +259,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); @@ -351,20 +353,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -457,20 +459,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -547,20 +549,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), @@ -605,20 +607,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp index 80c65515e89..aded984c1e1 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( @@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -285,20 +287,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -356,20 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -411,20 +413,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -462,20 +464,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -629,6 +631,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 @@ -821,20 +825,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -942,20 +946,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1039,20 +1043,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1123,20 +1127,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1223,7 +1227,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale, @@ -1233,7 +1237,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 7203348418a..a2053dfd397 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3 - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -295,9 +297,9 @@ struct BlockwiseGemmXdlops_pipeline_v3( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -364,20 +366,20 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -424,20 +426,20 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index a7d22066acd..08d5d1f9561 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -329,9 +331,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}) == 1, "Pipeline v3 only support scaleblocksliceN=1"); // assume kperblock = scaleblockk - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); @@ -476,20 +478,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -578,20 +580,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp index 3179a90b7fd..a9b399ea5bb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -307,13 +309,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // B scale buffer - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 @@ -429,20 +431,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -491,20 +493,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index 9835d9325b4..c762b3be15f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); StaticallyIndexedArray{}> a_thread_bufs; @@ -369,22 +371,22 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf] [Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf] [Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -439,20 +441,20 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -492,20 +494,20 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -524,20 +526,20 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp index f35c7a97cc3..3819f572c0f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // B scale buffer - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); StaticallyIndexedArray{}> a_thread_bufs; @@ -478,22 +480,22 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf] [Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf] [Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -549,20 +551,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -603,20 +605,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -635,20 +637,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp index 99934fa74e2..d5bc6369ddc 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v5( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -405,8 +407,8 @@ struct BlockwiseGemmXdlops_pipeline_v5 a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KRepeat, 1>{}([&](auto k0) { if constexpr(k0 == (KRepeat - 1)) @@ -427,18 +429,18 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -481,8 +483,8 @@ struct BlockwiseGemmXdlops_pipeline_v5 a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KRepeat, 1>{}([&](auto k0) { if constexpr(k0 == (KRepeat - 1)) @@ -497,18 +499,18 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -540,25 +542,25 @@ struct BlockwiseGemmXdlops_pipeline_v5 a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KRepeat - 1, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -591,16 +593,16 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = a_thread_buf + a_thread_vec.template AsType()(ik) = a_thread_buf [Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = b_thread_buf + b_thread_vec.template AsType()(ik) = b_thread_buf [Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -637,7 +639,7 @@ struct BlockwiseGemmXdlops_pipeline_v5{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -647,7 +649,7 @@ struct BlockwiseGemmXdlops_pipeline_v5; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index bbc2a54c342..0920b3277e6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -24,6 +24,7 @@ using BF8 = ck::bf8_t; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -205,6 +206,27 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, TF32, TF32> + // clang-format on + >; + // double rate mfma instances on gfx950 template >>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -232,6 +247,21 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -281,6 +311,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW @@ -386,6 +432,21 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instances( + std::vector>>& instances) #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 92bdfd88a50..b9b67da4c18 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -75,6 +75,7 @@ set(GROUPED_CONV2D_FWD # NGCHW, GKCYX, NGKHW xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..ad9ad654aab --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..352aa82d9fd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index 61f7f4421ee..13d890cd76b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -210,6 +210,15 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in + NUM_SHARDS 4 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in new file mode 100644 index 00000000000..d12ae33a8e3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index c370e613332..6296dac50ff 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -29,4 +29,5 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..61b471cb1c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index 5183c66a0af..7bda8c6b83d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -29,4 +29,5 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..9977482f8a1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 3678c32e6c9..fdad851b4e7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -38,8 +38,10 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp -xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..63ff09234cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..bb62769b3b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index 99a4d05d10c..626aa302d48 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -219,6 +219,15 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in + NUM_SHARDS 4 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in new file mode 100644 index 00000000000..352b8207b3e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index aef445a62e4..c76b6cb7846 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -24,6 +24,7 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..4b60dd1b3ef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 260deba2de9..82b7a56be27 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -24,6 +24,7 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 00000000000..3a99d693f93 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 823ee0717fe32c2ebd2278787d6659fb0be43ddc Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 18 Sep 2025 16:12:03 +0800 Subject: [PATCH 03/19] add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances --- ...vice_grouped_conv_fwd_xdl_mem_instance.hpp | 38 ++++++++- .../gpu/grouped_convolution_forward.hpp | 12 +++ ...grouped_convolution_forward_bias_clamp.hpp | 8 ++ ...ped_convolution_forward_bias_clamp_xdl.inc | 64 ++++++++++++++ .../gpu/grouped_convolution_forward_clamp.hpp | 8 ++ .../grouped_convolution_forward_clamp_xdl.inc | 64 ++++++++++++++ ...uped_convolution_forward_mem_inter_xdl.inc | 46 ++++++++++ ...uped_convolution_forward_mem_intra_xdl.inc | 71 +++++++++++++--- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 4 + ...kcyx_ngkhw_f32_tf32_mem_inter_instance.cpp | 42 +++++++++ ...kcyx_ngkhw_f32_tf32_mem_intra_instance.cpp | 42 +++++++++ ...kyxc_nhwgk_f32_tf32_mem_inter_instance.cpp | 70 +++++++++++++++ ...kyxc_nhwgk_f32_tf32_mem_intra_instance.cpp | 70 +++++++++++++++ .../CMakeLists.txt | 20 +++++ ...gkyxc_nhwgk_f32_tf32_mem_inter_instance.in | 85 +++++++++++++++++++ ...gkyxc_nhwgk_f32_tf32_mem_intra_instance.in | 85 +++++++++++++++++++ .../CMakeLists.txt | 2 + ...yxc_nhwgk_fp32_tf32_mem_inter_instance.cpp | 67 +++++++++++++++ ...yxc_nhwgk_fp32_tf32_mem_intra_instance.cpp | 67 +++++++++++++++ .../grouped_conv2d_fwd_clamp/CMakeLists.txt | 2 + ...yxc_nhwgk_fp32_tf32_mem_inter_instance.cpp | 67 +++++++++++++++ ...yxc_nhwgk_fp32_tf32_mem_intra_instance.cpp | 67 +++++++++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 16 ++++ ...yxc_ndhwgk_f32_tf32_mem_inter_instance.cpp | 59 +++++++++++++ ...yxc_ndhwgk_f32_tf32_mem_intra_instance.cpp | 59 +++++++++++++ ...czyx_ngkdhw_f32_tf32_mem_inter_instance.in | 67 +++++++++++++++ ...czyx_ngkdhw_f32_tf32_mem_intra_instance.in | 67 +++++++++++++++ .../CMakeLists.txt | 19 +++++ ...zyxc_ndhwgk_f32_tf32_mem_inter_instance.in | 85 +++++++++++++++++++ ...zyxc_ndhwgk_f32_tf32_mem_intra_instance.in | 85 +++++++++++++++++++ .../CMakeLists.txt | 2 + ...xc_ndhwgk_fp32_tf32_mem_inter_instance.cpp | 65 ++++++++++++++ ...xc_ndhwgk_fp32_tf32_mem_intra_instance.cpp | 65 ++++++++++++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 2 + ...xc_ndhwgk_fp32_tf32_mem_inter_instance.cpp | 65 ++++++++++++++ ...xc_ndhwgk_fp32_tf32_mem_intra_instance.cpp | 65 ++++++++++++++ 36 files changed, 1708 insertions(+), 14 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp index 57bdeddcf90..44ef8a622ce 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -24,6 +24,7 @@ using BF8 = ck::bf8_t; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -64,7 +65,7 @@ using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // Latency friendly + // Latency friendly DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -163,6 +164,41 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32> + // clang-format on + >; + template >>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index 5b3a6e7dc50..d129602695b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -135,6 +135,10 @@ struct DeviceOperationInstanceFactory>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc index 00351ceefda..a4fb152828e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -120,6 +135,21 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -169,6 +199,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instan PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc index bd44116057f..61e9cda9367 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -120,23 +135,38 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances PassThrough, PassThrough, PassThrough>>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( - std::vector>>& instances); + PassThrough, + TF32, + TF32>>>& instances) +#endif + +#ifdef CK_ENABLE_BF16 + // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK + void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -169,6 +199,21 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instan PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index b9b67da4c18..0553f262380 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -50,18 +50,22 @@ set(GROUPED_CONV2D_FWD xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp # NHWGC, GKYXC, NHWGK xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp # NGCHW, GKCYX, NGKHW xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp # NGCHW, GKCYX, NGKHW xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp #comp # NHWGC, GKYXC, NHWGK xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp new file mode 100644 index 00000000000..98e52ab15da --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp new file mode 100644 index 00000000000..5585de5b4a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp new file mode 100644 index 00000000000..676e2d4a27b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp new file mode 100644 index 00000000000..5601638e774 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index 13d890cd76b..71d1913b4ca 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -152,6 +152,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + # NHWGC, GKYXC, NHWGK set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) @@ -180,6 +190,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + #comp # NHWGC, GKYXC, NHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in new file mode 100644 index 00000000000..f516770698f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in new file mode 100644 index 00000000000..75aabfaa941 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index 6296dac50ff..ff92bb83be0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -27,7 +27,9 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 00000000000..d9835d7658c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 00000000000..43c04443c41 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index 7bda8c6b83d..c3e58be1cb4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -27,7 +27,9 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 00000000000..b1e53145e36 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 00000000000..74555cc227b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index fdad851b4e7..06de44a1a3a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -33,10 +33,12 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp @@ -106,6 +108,13 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV3D_FWD OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances @@ -128,6 +137,13 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV3D_FWD OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp new file mode 100644 index 00000000000..fe6141ac694 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp new file mode 100644 index 00000000000..633123e3c81 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in new file mode 100644 index 00000000000..00e39603e71 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in new file mode 100644 index 00000000000..9e13fddd32e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index 626aa302d48..83f0c78e003 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -163,6 +163,15 @@ generate_sharded_instantiations( ) # NDHWGC, GKZYXC, NDHWGK +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances @@ -189,6 +198,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + #comp # NDHWGC, GKZYXC, NDHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in new file mode 100644 index 00000000000..b87dce84118 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in new file mode 100644 index 00000000000..c1df1e262ed --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index c76b6cb7846..cccca804ed1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -22,7 +22,9 @@ set(GROUPED_CONV3D_FWD xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 00000000000..765719c7b56 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 00000000000..0daf28adef0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 82b7a56be27..ff9a5724f5e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -22,7 +22,9 @@ set(GROUPED_CONV3D_FWD xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 00000000000..905da7e1d09 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 00000000000..008dd28921e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 255a25d5b96aa86f0d1925376fdd10520f1a16ec Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 18 Sep 2025 16:35:35 +0800 Subject: [PATCH 04/19] add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances --- ...ped_conv_fwd_xdl_large_tensor_instance.hpp | 22 ++++++++ .../gpu/grouped_convolution_forward.hpp | 4 ++ ...grouped_convolution_forward_bias_clamp.hpp | 4 ++ ...ped_convolution_forward_bias_clamp_xdl.inc | 40 ++++++++++++-- .../gpu/grouped_convolution_forward_clamp.hpp | 4 ++ .../grouped_convolution_forward_clamp_xdl.inc | 32 +++++++++++ ...d_convolution_forward_xdl_large_tensor.inc | 32 +++++++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 1 + ...or_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 41 ++++++++++++++ .../CMakeLists.txt | 10 ++++ ...sor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in | 54 +++++++++++++++++++ .../CMakeLists.txt | 2 +- ...r_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 43 +++++++++++++++ .../grouped_conv2d_fwd_clamp/CMakeLists.txt | 1 + ...r_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp | 43 +++++++++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 41 ++++++++++++++ .../CMakeLists.txt | 10 ++++ ..._ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in | 54 +++++++++++++++++++ .../CMakeLists.txt | 1 + ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 43 +++++++++++++++ .../grouped_conv3d_fwd_clamp/CMakeLists.txt | 1 + ...dhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp | 43 +++++++++++++++ 23 files changed, 522 insertions(+), 5 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp index 5a4a0115128..10040251731 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp @@ -16,6 +16,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -99,6 +100,27 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> + // clang-format on + >; + template >>& instances); -void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( std::vector>>& instances); + AddClamp, + TF32, + TF32>>>& instances); + + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp>>>& instances); void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector>>& instances); + AddClamp, + TF32, + TF32>>>& instances) + + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp>>>& instances); void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances); + #endif #ifdef CK_ENABLE_INT8 @@ -120,6 +136,22 @@ void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_ins PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances) #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 0553f262380..b5b732e87aa 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -32,6 +32,7 @@ set(GROUPED_CONV2D_FWD xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instance.cpp # merged groups # NHWGC, GKYXC, NHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..9a81ccbb82e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index 71d1913b4ca..47ae049ddcc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -85,6 +85,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in + NUM_SHARDS 2 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor +) + # merged groups # NHWGC, GKYXC, NHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in new file mode 100644 index 00000000000..6073ad94d3e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index ff92bb83be0..0766b00d86f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -23,7 +23,7 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp - xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..b982a92b020 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index c3e58be1cb4..f0404cd0f42 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -24,6 +24,7 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..f4933e62b8c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 06de44a1a3a..3c58e4d0c72 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -20,6 +20,7 @@ set(GROUPED_CONV3D_FWD xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..b6c8cd1bdb3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index 83f0c78e003..b6377ba2b41 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -94,6 +94,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in + NUM_SHARDS 2 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor +) + # merged groups # NDHWGC, GKZYXC, NDHWGK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in new file mode 100644 index 00000000000..74308b1c9db --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index cccca804ed1..ef7cc22bc49 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -19,6 +19,7 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..04d750d2b92 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index ff9a5724f5e..0c126b20843 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -19,6 +19,7 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 00000000000..58595768354 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 58a3fa167cf978a2a544544ab74a7d849923dacd Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 18 Sep 2025 17:06:35 +0800 Subject: [PATCH 05/19] review --- example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp | 22 ++++----- .../convnd_fwd_xdl_fp32_tf32.cpp | 22 ++++----- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 6 +-- ...kwise_gemm_pipeline_xdlops_v3_ab_scale.hpp | 6 +-- ...ckwise_gemm_pipeline_xdlops_v3_b_scale.hpp | 6 +-- .../gpu/block/blockwise_gemm_xdlops.hpp | 30 ++++++------ .../gpu/grouped_convolution_forward.hpp | 38 +++++++-------- ...grouped_convolution_forward_bias_clamp.hpp | 4 +- .../gpu/grouped_convolution_forward_clamp.hpp | 4 +- ...d_convolution_forward_xdl_large_tensor.inc | 2 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 1 + .../CMakeLists.txt | 48 +++++++++---------- .../CMakeLists.txt | 1 + 13 files changed, 97 insertions(+), 93 deletions(-) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index d4e2a8e0b71..40c38b39d87 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -46,32 +46,32 @@ using DeviceGroupedConvNDFwdInstance = GemmSpec, // GemmSpecialization 1, // 256, // BlockSize - 64, // MPerBlock - 64, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 - 16, // MPerXdl - 16, // NPerXdl + 128, // MPerBlock + 256, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl 2, // MXdlPerWave - 2, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 4, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 + 4, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 4, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 + 4, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockLdsExtraN 1, 1, - S<1, 32, 1, 4>, + S<1, 16, 1, 16>, 4>; #include "run_convnd_fwd_example.inc" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp index 9264aee24d6..348da7e1ef4 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp @@ -49,32 +49,32 @@ using DeviceGroupedConvNDFwdInstance = GemmSpec, // GemmSpecialization 1, // NumGemmKPrefetchStage 256, // BlockSize - 64, // MPerBlock - 64, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 - 16, // MPerXdl - 16, // NPerXdl + 128, // MPerBlock + 192, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl 2, // MXdlPerWave - 2, // NXdlPerWave + 3, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 4, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_AK1 + 4, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 4, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_BK1 + 4, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 4, // CDEBlockTransferScalarPerVector_NPerBlock ComputeDataType, // AComputeDataType ComputeDataType, // BComputeDataType diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index a2053dfd397..f797c611a8f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -199,9 +199,9 @@ struct BlockwiseGemmXdlops_pipeline_v3 - // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) - // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) - // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index 08d5d1f9561..3f4f7ea7e8b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -199,9 +199,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale - // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) - // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) - // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp index a9b399ea5bb..35be8b9551e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -198,9 +198,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale - // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) - // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) - // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 9fdd12adfbb..55015dd30f7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -69,7 +69,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr auto xdlops_gemm = - XdlopsGemm{}; + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -637,19 +637,21 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() } else if constexpr(LoopSched == LoopScheduler::Interwave) { - return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< + BlockSize, + FloatA, + FloatB, + FloatAcc, + AK0MK1BlockDesc, + BK0NK1BlockDesc, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack, + ComputeTypeA, + ComputeTypeB, + CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{}; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 45c50fea0d9..fcfb276fece 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -129,16 +129,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances(op_ptrs); + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances(op_ptrs); + } + else + { + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -393,16 +393,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances(op_ptrs); + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances(op_ptrs); + } + else + { + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -482,7 +482,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index d91ae178f42..bf5ce173e75 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -129,7 +129,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -217,7 +217,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index 2c7a4829372..778d62b7333 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -127,7 +127,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -214,7 +214,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc index c1626d61ea2..e67d71f8aba 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc @@ -151,7 +151,7 @@ void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf3 PassThrough, PassThrough, TF32, - TF32>>>& instances) + TF32>>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index b5b732e87aa..0e48c974cae 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -72,6 +72,7 @@ set(GROUPED_CONV2D_FWD xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index 47ae049ddcc..a801144bfd0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -2,7 +2,7 @@ set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances @@ -11,7 +11,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_instances @@ -20,7 +20,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances @@ -29,7 +29,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances @@ -38,7 +38,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances @@ -47,7 +47,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances @@ -58,7 +58,7 @@ generate_sharded_instantiations( ) # large tensor # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances @@ -67,7 +67,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances @@ -76,7 +76,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances @@ -97,7 +97,7 @@ generate_sharded_instantiations( # merged groups # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances @@ -106,7 +106,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances @@ -115,7 +115,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances @@ -135,7 +135,7 @@ generate_sharded_instantiations( ) #mem # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances @@ -144,7 +144,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances @@ -153,7 +153,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances @@ -173,7 +173,7 @@ generate_sharded_instantiations( ) # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances @@ -182,7 +182,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances @@ -191,7 +191,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances @@ -212,7 +212,7 @@ generate_sharded_instantiations( #comp # NHWGC, GKYXC, NHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances @@ -221,7 +221,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances @@ -230,7 +230,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances @@ -257,7 +257,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances @@ -266,7 +266,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances @@ -275,7 +275,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index 0766b00d86f..41274f8027a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -23,6 +23,7 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp From 7f6962e494dacbcab6639e0728f0b227d031b0c8 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Fri, 19 Sep 2025 11:15:56 +0800 Subject: [PATCH 06/19] tf32:conv:add instances for base class DeviceConvFwd --- .../gpu/device/device_conv_fwd.hpp | 3 +- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 17 ++- .../gpu/conv2d_fwd/CMakeLists.txt | 1 + ...d_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp | 136 ++++++++++++++++++ 4 files changed, 149 insertions(+), 8 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp index 4dc11dbefd7..9859b6d5854 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -20,7 +20,8 @@ template + typename OutElementwiseOperation, + typename ComputeDataType = InDataType> struct DeviceConvFwd : public BaseOperator { virtual std::unique_ptr diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index cecfa48408f..16a53c24b2b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -54,7 +54,8 @@ template + ck::index_t CThreadTransferDstScalarPerVector, + typename ComputeDataType = InDataType> struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K : public DeviceConvFwd<2, ck::tensor_layout::convolution::NHWC, @@ -65,7 +66,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K OutDataType, InElementwiseOperation, WeiElementwiseOperation, - OutElementwiseOperation> + OutElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; @@ -78,7 +80,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CDataType = OutDataType; // TODO make A/B datatype different - using ABDataType = InDataType; + using ABDataTypeElementwise = ADataType; // for load/store and elementwise operation + using ABDataTypeGemm = ComputeDataType; // only for gemm computation static constexpr index_t NDimSpatial = 2; @@ -331,7 +334,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, - ABDataType, // TODO: distinguish A/B datatype + ABDataTypeGemm, // TODO: distinguish A/B datatype AccDataType, CDataType, InMemoryDataOperationEnum::Set, @@ -472,7 +475,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { const auto kernel = kernel_gemm_xdlops_v2r3()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 04b313d075b..028d5b518f1 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,6 +1,7 @@ # ONLY XDL_KERNELS set(DEVICE_CONV2D_FWD_INSTANCES) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..ffcdf57d220 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP32 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using TF32 = ck::tf32_t; + +template +using S = ck::Sequence; + +using NHWC = ck::tensor_layout::convolution::NHWC; +using KYXC = ck::tensor_layout::convolution::KYXC; +using NHWK = ck::tensor_layout::convolution::NHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_tf32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_tf32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instances( + std::vector>>& instances) +{ +#if CK_BUILD_DEPRECATED +#pragma message "These instances are getting deprecated" + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_tf32_instances{}); + add_device_operation_instances( + instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_tf32_instances{}); +#else +#pragma message "These instances were deprecated" + std::ignore = instances; +#endif +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif From ddfc65dde00c54add79abcd37a046a37a08920c4 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Fri, 19 Sep 2025 13:30:42 +0800 Subject: [PATCH 07/19] tf32:conv:add instances for base class DeviceGroupedConvBwdDataMultipleD --- ...d_conv_bwd_data_transpose_xdl_instance.hpp | 43 ++++++- ...ed_conv_bwd_data_xdl_bilinear_instance.hpp | 38 ++++++ ...ice_grouped_conv_bwd_data_xdl_instance.hpp | 82 ++++++++++++- ...ouped_conv_bwd_data_xdl_scale_instance.hpp | 38 ++++++ .../gpu/grouped_convolution_backward_data.hpp | 64 ++++++++-- .../grouped_convolution_backward_data_xdl.inc | 110 ++++++++++++++++++ .../grouped_conv2d_bwd_data/CMakeLists.txt | 7 ++ ...dl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp | 51 ++++++++ ...hw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp | 42 +++++++ ...dl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 42 +++++++ ..._ngkhw_f32_tf32_vec_transpose_instance.cpp | 42 +++++++ ...dl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp | 42 +++++++ ...gc_gkyxc_nhwgk_f32_tf32_16_16_instance.cpp | 51 ++++++++ ...dl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 51 ++++++++ ...gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp | 52 +++++++++ ..._gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp | 51 ++++++++ ..._gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp | 42 +++++++ ...ngkdhw_f32_tf32_vec_transpose_instance.cpp | 42 +++++++ ...ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp | 42 +++++++ ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 51 ++++++++ 20 files changed, 968 insertions(+), 15 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp index e535ba0170a..48b76ed8d83 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp @@ -18,6 +18,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; using BF8 = ck::bf8_t; using F8 = ck::f8_t; @@ -84,17 +85,17 @@ using device_grouped_conv_bwd_data_transpose_xdl_bf16_instances = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>, - + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>, - + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>, - + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>, @@ -138,6 +139,42 @@ using device_grouped_conv_bwd_data_transpose_xdl_f32_instances = // clang-format on >; +// f32_f32_f32_f32 tf32 +template +using device_grouped_conv_bwd_data_transpose_xdl_f32_tf32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopSched| AComputeType| BComputeType| MaxTranspose| MaxTranspose| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | | | TransferIn| TransferOut| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | ScalarPer| ScalarPer| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Vector| Vector| + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp index 216b4e2fe7c..e5f1bdc3e7d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp @@ -18,6 +18,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -143,6 +144,43 @@ using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = // clang-format on >; +// f32_f32_f32_f32 tf32 +template +using device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4> , 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4> , 8, make_default_loop_scheduler(), TF32, TF32> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 11a8ff8e91f..064666c8641 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -18,6 +18,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; using BF8 = ck::bf8_t; using F8 = ck::f8_t; @@ -127,13 +128,13 @@ using device_grouped_conv_bwd_data_xdl_f16_nchw_instances = // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, @@ -293,6 +294,83 @@ using device_grouped_conv_bwd_data_xdl_f32_instances = // clang-format on >; +// f32_f32_f32_f32 tf32 +template +using device_grouped_conv_bwd_data_xdl_f32_tf32_generic_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_f32_tf32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, make_default_loop_scheduler(), TF32, TF32>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32> + // clang-format on + >; + // f16_f16_f16_comp_f8 template using S = ck::Sequence; @@ -143,6 +144,43 @@ using device_grouped_conv_bwd_data_xdl_scale_f32_instances = // clang-format on >; +// f32_f32_f32_f32 tf32 +template +using device_grouped_conv_bwd_data_xdl_scale_f32_tf32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index e9ff75a91d0..b16f6f2754c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -87,7 +87,16 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); + if(is_same_v && is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_BF16 @@ -118,9 +127,21 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( - op_ptrs); + + if(is_same_v && is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_BF16 @@ -151,7 +172,16 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs); + if(is_same_v && is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_BF16 @@ -184,11 +214,25 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( - op_ptrs); + + if(is_same_v && is_same_v) + { + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_16_16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_vec_transpose_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_BF16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index c723be0db8a..be2cbff9e4c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -38,6 +38,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( @@ -112,6 +127,38 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instance PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( @@ -172,6 +219,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances( @@ -274,6 +336,54 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_ PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_16_16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_vec_transpose_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 0ef09c55eee..6b8df51e491 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -4,24 +4,31 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_16_16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..4bfd07f60d0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< + 2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp new file mode 100644 index 00000000000..5fe6268a091 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_16_16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..cc103cd4f13 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp new file mode 100644 index 00000000000..93f217ca054 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_f32_tf32_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..6af5fc7fbba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_generic_instances<2, + NGKHW, + GKYXC, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16_16_instance.cpp new file mode 100644 index 00000000000..9dafbfe77d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16_16_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< + 2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..56dc9222d48 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..c9223e42ea4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// wo, k] +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< + 3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp new file mode 100644 index 00000000000..63e90333a96 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp new file mode 100644 index 00000000000..cea4aac2ff2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_tf32_16_16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp new file mode 100644 index 00000000000..4b12fece2d1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_tf32_vec_transpose_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_transpose_xdl_f32_tf32_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..39bcb567be0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_generic_instances<3, + NGKDHW, + GKZYXC, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..12b36b77ca0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( + std::vector, + NDHWGC, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + Bilinear, + TF32, + TF32>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances< + 3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From de9a5507fa76f999c2e7fa167682d1eec78f218c Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Mon, 22 Sep 2025 10:42:55 +0800 Subject: [PATCH 08/19] tf32:conv:add instances for base class DeviceGroupedConvBwdWeight --- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 18 +- ...rouped_conv_bwd_weight_v3_xdl_instance.hpp | 21 +- ..._conv_bwd_weight_xdl_bilinear_instance.hpp | 114 +++++--- ...e_grouped_conv_bwd_weight_xdl_instance.hpp | 157 ++++++---- ...ped_conv_bwd_weight_xdl_scale_instance.hpp | 112 +++++--- .../grouped_convolution_backward_weight.hpp | 216 ++++++++++---- ...d_convolution_backward_weight_bilinear.hpp | 27 +- ...uped_convolution_backward_weight_scale.hpp | 29 +- ...rouped_convolution_backward_weight_xdl.inc | 268 +++++++++++++++++- .../grouped_conv1d_bwd_weight/CMakeLists.txt | 1 + ...t_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp | 47 +++ .../grouped_conv2d_bwd_weight/CMakeLists.txt | 12 +- ...gnhwk_f32_tf32_default_pipev1_instance.cpp | 42 +++ ...dl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp | 50 ++++ ...xc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp | 42 +++ ...dl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 51 ++++ ...dl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp | 41 +++ ...nhwgk_f32_tf32_default_pipev2_instance.cpp | 42 +++ ...nhwgk_f32_tf32_default_pipev5_instance.cpp | 42 +++ ...dl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp | 48 ++++ ...xc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp | 42 +++ ...xc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp | 42 +++ .../grouped_conv3d_bwd_weight/CMakeLists.txt | 8 + ...gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp | 49 ++++ ...dhwgk_f32_tf32_default_pipev2_instance.cpp | 42 +++ ...dhwgk_f32_tf32_default_pipev5_instance.cpp | 42 +++ ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 48 ++++ ...c_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp | 42 +++ ...c_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp | 42 +++ ...ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp | 51 ++++ ...ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp | 41 +++ .../CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 52 ++++ .../CMakeLists.txt | 1 + ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 52 ++++ 35 files changed, 1726 insertions(+), 209 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index abb8c52e0f5..cb841c36eaa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -280,8 +280,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using FloatBAdjusted = conditional_t, ck::bhalf_t, ComputeTypeB>; #else - using FloatAAdjusted = ComputeTypeA; - using FloatBAdjusted = ComputeTypeB; + using FloatAAdjusted = conditional_t, float, ComputeTypeA>; + using FloatBAdjusted = conditional_t, float, ComputeTypeB>; #endif // M0/M1/M1Padding @@ -760,19 +760,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // register // sanity check constexpr bool is_single_rate_mfma = - (((is_same::value || is_same::value) && + (((is_same::value || is_same::value) && K1 <= 4) || - (is_same::value && K1 <= 8) || - ((is_same::value || is_same::value) && + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && K1 < 32)) ? true : false; constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(K1, - MfmaSelector::selected_mfma.k_per_blk); @@ -787,7 +787,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight NPerXdl, MRepeat, NRepeat, - KPack>{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp index b445e0001d9..114a6cab35d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; #ifdef CK_ENABLE_FP8 using F8 = ck::f8_t; @@ -58,6 +59,24 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_instances = std::tuple // clang-format on >; +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, TF32, TF32> + // clang-format on + >; + template , S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp index 8b830d91d54..a81ec510819 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp @@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; #ifdef CK_ENABLE_FP8 using F8 = ck::f8_t; @@ -74,6 +75,39 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_bilinear_instances = std: // clang-format on >; +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_bilinear_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>, + + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F32, F32, F32, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32> + // clang-format on + >; + template , F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -122,23 +156,23 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances = std //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, // instance for small conv.K - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; @@ -156,25 +190,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_bilinear_inst //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> #endif // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp index 3587570e427..10f4f3b69a0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; #ifdef CK_ENABLE_FP8 using F8 = ck::f8_t; @@ -96,6 +97,62 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| Compute| Compute| MaxTranspose| MaxTranspose| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| TypeA| TypeB| TransferSrc| TransferDst| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | + // generic instance + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, S<1, 8, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, 2, 1, 4, true, 1, 1, S<1, 4, 1, 64>, 1, TF32, TF32, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> + // clang-format on + >; + template , S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -179,23 +236,23 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances = std //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -214,25 +271,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple< //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for bf16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -252,25 +309,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = s //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> #endif // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp index dc4c8fa8048..dc365c4fdcf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp @@ -18,6 +18,7 @@ using namespace ck::tensor_layout::convolution; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; #ifdef CK_ENABLE_FP8 using F8 = ck::f8_t; @@ -74,6 +75,39 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_scale_instances = std::tu // clang-format on >; +template +using device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_scale_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| DsData| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| + //#########################################| Dim| | | | Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| + //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + // generic instance + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>, + + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 64, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 1, true, S<1, 4, 32, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F32, F32, F32, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 2, true, S<1, 4, 16, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4, TF32, TF32> + // clang-format on + >; + template , S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -121,23 +155,23 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances = std::t //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, // instance for small conv.K - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; @@ -155,25 +189,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_scale_instanc //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> #endif // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 3c0784eef3a..b05f70d42f4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -277,7 +277,18 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -310,13 +321,25 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - op_ptrs); + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( - op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -349,20 +372,37 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); + is_same_v) + { + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances( - op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -514,11 +554,20 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( - op_ptrs); + is_same_v) + { + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( + op_ptrs); + } } #endif } @@ -547,11 +596,20 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( - op_ptrs); + is_same_v) + { + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( + op_ptrs); + } } #endif } @@ -563,11 +621,20 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - op_ptrs); + is_same_v) + { + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -595,20 +662,37 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( - op_ptrs); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -769,11 +853,20 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( - op_ptrs); + is_same_v) + { + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( + op_ptrs); + } } #endif } @@ -802,11 +895,20 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances( - op_ptrs); + is_same_v) + { + static_assert(is_same_v, + "Error: ComputeTypeA and ComputeTypeB should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instances( + op_ptrs); + } } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp index 50b6f0b6d8a..b8904ebc5d0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp @@ -62,6 +62,21 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_ PassThrough, Bilinear, PassThrough>>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + F32, + F32, + F32, + Tuple, + PassThrough, + Bilinear, + PassThrough, + TF32, + TF32>>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( @@ -141,8 +156,16 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); + if(is_same_v && is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index 89a28489203..af24943ca6e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -62,6 +62,22 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins PassThrough, Scale, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + F32, + F32, + F32, + Tuple<>, + PassThrough, + Scale, + PassThrough, + TF32, + TF32>>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( @@ -141,8 +157,17 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); + if constexpr(is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc index 31926ce9084..609ea925d7c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -47,6 +47,19 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( + std::vector>>& instances); #endif // conv2d backward weight #ifdef CK_ENABLE_BF16 @@ -124,6 +137,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipe PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -570,6 +611,20 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances( + std::vector>>& instances); #endif // conv3d backward weight #ifdef CK_ENABLE_BF16 @@ -681,6 +820,20 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -1117,6 +1270,20 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( + std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances( std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( std::vector>>& instances); + TF32, + TF32>>>& instances) +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index b057e0c8d27..7f4e94da487 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -2,6 +2,7 @@ set(GROUPED_CONV1D_BWD_WEIGHT xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp + xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp) if(DL_KERNELS) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..258e67e3e7e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances< + 1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 7264c4688d6..5eb7650746c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -2,14 +2,18 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp + xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp + xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instance.cpp @@ -21,9 +25,13 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev5_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev5_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev2_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp @@ -39,9 +47,10 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp + xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp - xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instance.cpp + xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev5_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp @@ -50,6 +59,7 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp + xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp new file mode 100644 index 00000000000..404fe4f5225 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..569edd62ff6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp new file mode 100644 index 00000000000..91bee86045c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..9d56ee5c03a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..f39bbc7120d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< + 2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instance.cpp new file mode 100644 index 00000000000..e3161c5ff4e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instance.cpp new file mode 100644 index 00000000000..65b811e068a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..6bfbf4ee739 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp new file mode 100644 index 00000000000..f1d425fab41 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp new file mode 100644 index 00000000000..70b1739ee10 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 5574cf82f9f..301641ffde3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -2,10 +2,12 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp @@ -17,9 +19,13 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp @@ -34,11 +40,13 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp + xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..88746535329 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp new file mode 100644 index 00000000000..dab91ec7475 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp new file mode 100644 index 00000000000..01229234ff2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..ac6c3b60e40 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp new file mode 100644 index 00000000000..c479cc20481 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp new file mode 100644 index 00000000000..cfb0e8a65e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..43719e9339e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp new file mode 100644 index 00000000000..a819c3fe996 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< + 3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt index 329e8e4c7f7..b8621e73aae 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt @@ -2,6 +2,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..522598ba87d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + F32, + F32, + F32, + Tuple, + PassThrough, + Bilinear, + PassThrough, + TF32, + TF32>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_bilinear_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt index 9a42d1ec3a2..5277b04ed4c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt @@ -2,6 +2,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT_SCALE xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..37692f3478a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + F32, + F32, + F32, + Tuple<>, + PassThrough, + Scale, + PassThrough, + TF32, + TF32>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_scale_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_scale_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From b3db6c1acd4180c0538346a91a9b51496c1814bf Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 23 Sep 2025 14:23:48 +0800 Subject: [PATCH 09/19] self review --- .../gpu/grouped_convolution_backward_data.hpp | 30 ++++++++++--------- .../grouped_convolution_backward_weight.hpp | 6 ++-- ...d_convolution_backward_weight_bilinear.hpp | 7 +++-- ...uped_convolution_backward_weight_scale.hpp | 3 +- ...rouped_convolution_backward_weight_xdl.inc | 2 +- .../gpu/grouped_convolution_forward.hpp | 24 +++++++++++---- ...grouped_convolution_forward_bias_clamp.hpp | 8 +++-- ...ped_convolution_forward_bias_clamp_xdl.inc | 30 +++++++++---------- .../gpu/grouped_convolution_forward_clamp.hpp | 8 +++-- .../grouped_convolution_forward_comp_xdl.inc | 2 +- ...uped_convolution_forward_mem_intra_xdl.inc | 30 +++++++++---------- ...rouped_convolution_forward_scaleadd_ab.hpp | 2 +- 12 files changed, 86 insertions(+), 66 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index b16f6f2754c..04608e1996e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -84,10 +84,11 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - if(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: this operator requires the same compute type"); + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_tf32_instances( op_ptrs); @@ -124,11 +125,11 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - - if(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: this operator requires the same compute type"); + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances( op_ptrs); @@ -169,10 +170,11 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - if(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: this operator requires the same compute type"); + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_tf32_instances( op_ptrs); @@ -211,11 +213,11 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - - if(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: this operator requires the same compute type"); + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index b05f70d42f4..6d20b39ad27 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -274,8 +274,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); @@ -318,8 +317,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { static_assert(is_same_v, "Error: ComputeTypeA and ComputeTypeB should be the same"); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp index b8904ebc5d0..ffe98602a39 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp @@ -153,10 +153,11 @@ struct DeviceOperationInstanceFactory< { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - if(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: this operator requires the same compute type"); + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index af24943ca6e..4bb44b62e4b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -154,8 +154,7 @@ struct DeviceOperationInstanceFactory< { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc index 609ea925d7c..7086f7034cf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -1437,7 +1437,7 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0 PassThrough, PassThrough, TF32, - TF32>>>& instances) + TF32>>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index fcfb276fece..d13038a0f2b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -131,7 +131,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances(op_ptrs); } @@ -177,7 +179,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances(op_ptrs); @@ -216,7 +220,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( @@ -318,7 +324,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instances( op_ptrs); @@ -395,7 +403,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances(op_ptrs); } @@ -614,7 +624,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index bf5ce173e75..e41e1b833bd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -129,7 +129,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -217,7 +219,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc index dc5e75d3cb6..988d678419c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc @@ -732,21 +732,21 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndh PassThrough, AddClamp, TF32, - TF32>>>& instances) - - void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector, - NDHWGK, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - AddClamp>>>& instances); + TF32>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp>>>& instances); void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( std::vector && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( op_ptrs); @@ -214,7 +216,9 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v && is_same_v) + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) { add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc index 87e5c0e9375..d0dcdb7b84d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -446,7 +446,7 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instan PassThrough, PassThrough, TF32, - TF32>>>& instances) + TF32>>>& instances); #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc index 61e9cda9367..0d07d9da43d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -149,24 +149,24 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_inst PassThrough, PassThrough, TF32, - TF32>>>& instances) + TF32>>>& instances); #endif #ifdef CK_ENABLE_BF16 - // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK - void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( - std::vector>>& instances); +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index 2dfaa7eb2b8..433660fd811 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory> && is_same_v> && - is_same_v && is_same_v) + is_same_v) { if constexpr(is_same_v) { From b3bb54f818595153d28ce2aaf02e75282bee40cf Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 23 Sep 2025 16:41:13 +0800 Subject: [PATCH 10/19] add tf32 in profiler --- include/ck/library/utility/check_err.hpp | 32 +++-- include/ck/utility/numeric_utils.hpp | 18 +++ .../profile_grouped_conv_bwd_data_impl.hpp | 46 ++++--- .../src/profile_grouped_conv_bwd_data.cpp | 114 +++++++++++++----- .../src/profile_grouped_conv_bwd_weight.cpp | 47 ++++++++ profiler/src/profile_grouped_conv_fwd.cpp | 36 ++++++ .../profile_grouped_conv_fwd_bias_clamp.cpp | 35 ++++-- .../src/profile_grouped_conv_fwd_clamp.cpp | 33 +++-- 8 files changed, 286 insertions(+), 75 deletions(-) diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index ade4d9a5b4a..84a166f1c5c 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -31,13 +31,15 @@ double get_relative_threshold(const int number_of_accumulations = 1) using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; + using TF32 = ck::tf32_t; using I8 = int8_t; using I32 = int32_t; static_assert(is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v, + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); double compute_error = 0; if constexpr(is_same_v || is_same_v || @@ -52,8 +54,9 @@ double get_relative_threshold(const int number_of_accumulations = 1) static_assert(is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v, + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, "Warning: Unhandled OutDataType for setting up the relative threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -69,8 +72,9 @@ double get_relative_threshold(const int number_of_accumulations = 1) static_assert(is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v, + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, "Warning: Unhandled AccDataType for setting up the relative threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || @@ -93,13 +97,15 @@ double get_absolute_threshold(const double max_possible_num, const int number_of using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; + using TF32 = ck::tf32_t; using I8 = int8_t; using I32 = int32_t; static_assert(is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v, + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); auto expo = std::log2(std::abs(max_possible_num)); double compute_error = 0; @@ -115,8 +121,9 @@ double get_absolute_threshold(const double max_possible_num, const int number_of static_assert(is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v, + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, "Warning: Unhandled OutDataType for setting up the absolute threshold!"); double output_error = 0; if constexpr(is_same_v || is_same_v || @@ -132,8 +139,9 @@ double get_absolute_threshold(const double max_possible_num, const int number_of static_assert(is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v, + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v, "Warning: Unhandled AccDataType for setting up the absolute threshold!"); double acc_error = 0; if constexpr(is_same_v || is_same_v || diff --git a/include/ck/utility/numeric_utils.hpp b/include/ck/utility/numeric_utils.hpp index 726f6675186..1c347069982 100644 --- a/include/ck/utility/numeric_utils.hpp +++ b/include/ck/utility/numeric_utils.hpp @@ -43,6 +43,24 @@ struct NumericUtils using bitwise_type = uint32_t; }; +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 10; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + static constexpr bool has_inf = true; + using bitwise_type = uint32_t; +}; + template <> struct NumericUtils { diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index 0aeefaabfbf..9a95c87958d 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -29,7 +29,8 @@ template + typename InDataType, + typename ComputeDataType = InDataType> bool profile_grouped_conv_bwd_data_impl(int do_verification, int init_method, bool do_log, @@ -95,7 +96,11 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>(); + OutElementOp, + 0, + 0, + 0, + ComputeDataType>(); auto ref_invoker = ref_conv.MakeInvoker(); @@ -164,9 +169,13 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, { in_device_buf.FromDevice(in_device.mData.data()); - using ComputeType = std::conditional_t; + using ComputeType_ = std::conditional_t; + using ComputeType = + std::conditional_t; using AccDataType = std::conditional_t, int32_t, float>; const index_t num_accums = conv_param.K_; @@ -209,18 +218,21 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, }; // do GEMM - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, - InLayout, - OutDataType, - WeiDataType, - ck::Tuple<>, - InDataType, - OutElementOp, - WeiElementOp, - InElementOp>; + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + OutElementOp, + WeiElementOp, + InElementOp, + ComputeDataType, + ComputeDataType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 5cdece499e1..62482fc35a7 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -21,9 +21,10 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + F32_F32_F32_TF32, // 3 }; #define OP_NAME "grouped_conv_bwd_data" @@ -37,6 +38,7 @@ static void print_helper_msg() << "arg2: data type (0: Output fp32, Weight fp32, Input fp32\n" << " 1: Output fp16, Weight fp16, Input fp16\n" << " 2: Output bf16, Weight bf16, Input bf16\n" + << " 3: Output fp32, Weight fp32, Input fp32, Compute tf32)\n" << "arg3: tensor layout (0: Output[G, N, Ho, Wo, C], Weight[G, K, Y, X, C], Input[G, N, Hi, Wi, K]\n" << " 1: Output[N, Ho, Wo, G, C], Weight[G, K, Y, X, C], Input[N, Hi, Wi, G, K])\n" << " 2: Output[N, G, C, Ho, Wo], Weight[G, K, Y, X, C], Input[N, G, K, Hi, Wi])\n" @@ -82,6 +84,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; +#if defined(__gfx942__) + using TF32 = ck::tf32_t; +#endif using namespace ck::tensor_layout::convolution; @@ -94,7 +99,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) auto in_layout, auto wei_type, auto out_type, - auto in_type) { + auto in_type, + auto compute_type) { constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; using OutLayout = decltype(out_layout); @@ -104,6 +110,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) using OutDataType = decltype(out_type); using WeiDataType = decltype(wei_type); using InDataType = decltype(in_type); + using ComputeDataType = decltype(compute_type); bool pass = ck::profiler::profile_grouped_conv_bwd_data_impl( + InDataType, + ComputeDataType>( do_verification, init_method, do_log, time_kernel, params, split_k); return pass ? 0 : 1; @@ -123,60 +131,84 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}); + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}); + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}); + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}); + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}); + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}); + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}); + return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F16{}, F16{}, F16{}); + return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, BF16{}, BF16{}, BF16{}); + return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}); + return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F16{}, F16{}, F16{}); + return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, BF16{}, BF16{}, BF16{}); + return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } } @@ -186,60 +218,84 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}); + return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{}); + return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{}); + return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}); + return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{}); + return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{}); + return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}); + return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F16{}, F16{}, F16{}); + return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, BF16{}, BF16{}, BF16{}); + return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}); + return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F16{}, F16{}, F16{}); + return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, BF16{}, BF16{}, BF16{}); + return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{}); +#endif } } } diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index 8347ce0e429..24c848e08b7 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -28,6 +28,7 @@ enum struct ConvDataType F16_F16_F16_BF8_F8, // 3 I8_I8_I8, // 4 BF16_BF16_BF16, // 5 + F32_F32_F32_TF32, // 6 }; #define OP_NAME "grouped_conv_bwd_weight" @@ -42,6 +43,7 @@ static void print_helper_msg() << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n" << " 4: Input int8, Weight int8, Output int8\n" << " 5: Input bf16, Weight bf16, Output bf16)\n" + << " 6: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " "N, K, Ho, Wo]\n" << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " @@ -97,6 +99,9 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; +#if defined(__gfx942__) + using TF32 = ck::tf32_t; +#endif using namespace ck::tensor_layout::convolution; @@ -155,6 +160,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } + if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { @@ -171,6 +182,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } + if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { @@ -191,6 +208,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) { @@ -218,6 +241,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); } + if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { @@ -239,6 +268,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) return profile( I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { @@ -269,6 +304,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) return profile( I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); } + if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) { @@ -297,6 +338,12 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); } + if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index a8d343405d1..13f5cd1cda6 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -226,6 +226,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { @@ -245,6 +251,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { @@ -292,6 +304,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { @@ -311,6 +329,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) { @@ -326,6 +350,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) { @@ -341,6 +371,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp index 34b3df1c654..fb1eedf2a7d 100644 --- a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp @@ -20,14 +20,15 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 - INT8_INT8_INT8, // 3 - F8_F8_F8, // 4 - BF8_BF8_F8, // 5 - F8_BF8_F8, // 6 - BF8_F8_F8, // 7 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 }; enum struct IndexType @@ -51,7 +52,8 @@ static void print_helper_msg() << " 4: Input fp8, Weight fp8, Output fp8\n" << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" - << " 7: Input bf8, Weight fp8, Output fp8)\n" + << " 7: Input bf8, Weight fp8, Output fp8\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " @@ -103,6 +105,9 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; +#if defined(__gfx942__) + using TF32 = ck::tf32_t; +#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -165,6 +170,12 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) { return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { @@ -181,6 +192,12 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) return profile( I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp index 600f91744aa..f23e2ddc110 100644 --- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -20,14 +20,15 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 - INT8_INT8_INT8, // 3 - F8_F8_F8, // 4 - BF8_BF8_F8, // 5 - F8_BF8_F8, // 6 - BF8_F8_F8, // 7 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 }; enum struct IndexType @@ -52,6 +53,7 @@ static void print_helper_msg() << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" << " 7: Input bf8, Weight fp8, Output fp8)\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " @@ -103,6 +105,9 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) using F32 = float; using BF16 = ck::bhalf_t; using F16 = ck::half_t; +#if defined(__gfx942__) + using TF32 = ck::tf32_t; +#endif using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -168,6 +173,12 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) { return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { @@ -184,6 +195,12 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) return profile( I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NDHWGC{}, NDHWGC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } std::cout << "this data_type & layout is not implemented" << std::endl; From 7a653cd08b93c232988e40f4a08f791edf1807b7 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 23 Sep 2025 15:54:25 +0800 Subject: [PATCH 11/19] remove useless instances --- ...wd_xdl_scaleadd_scaleadd_relu_instance.hpp | 23 ------- .../gpu/grouped_convolution_forward.hpp | 14 +---- ...rouped_convolution_forward_scaleadd_ab.hpp | 27 +-------- ...olution_forward_scaleadd_scaleadd_relu.hpp | 28 +-------- .../gpu/grouped_convolution_forward_xdl.inc | 15 ----- .../gpu/grouped_conv1d_fwd/CMakeLists.txt | 1 - ...d_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp | 56 ----------------- .../CMakeLists.txt | 1 - ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 54 ----------------- .../CMakeLists.txt | 1 - ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 60 ------------------- 11 files changed, 7 insertions(+), 273 deletions(-) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp index 91525b92d31..d62bec2b356 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp @@ -16,7 +16,6 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; -using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -104,28 +103,6 @@ using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances = std::tu // clang-format on >; -template -using device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances = std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, - // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, - - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, ScaleAddScaleAddRelu, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> - // clang-format on - >; - template && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: AComputeType and BComputeType should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances(op_ptrs); - } - else - { - add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); - } + add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index 433660fd811..d7a217f1b84 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -68,21 +68,6 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_inst ScaleAdd, ScaleAdd, PassThrough>>>& instances); -void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - F32, - ScaleAdd, - ScaleAdd, - PassThrough, - TF32, - TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -152,16 +137,8 @@ struct DeviceOperationInstanceFactory> && is_same_v) { - if constexpr(is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - op_ptrs); - } - else - { - add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - } + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp index 13894cac919..efb62664266 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp @@ -68,22 +68,6 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw PassThrough, PassThrough, ScaleAddScaleAddRelu>>>& instances); - -void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - ck::Tuple, - F32, - PassThrough, - PassThrough, - ScaleAddScaleAddRelu, - TF32, - TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -154,16 +138,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - if constexpr(is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - op_ptrs); - } - else - { - add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - } + add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index 72f9591915c..b0ba477dba9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -55,21 +55,6 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index 6bb7e202eb1..ca4ea515bb0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -3,6 +3,5 @@ add_instance_library(device_grouped_conv1d_fwd_instance xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp - xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp deleted file mode 100644 index 0078d8788c3..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<1, - GNWC, - GKXC, - Empty_Tuple, - GNWK, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<1, - GNWC, - GKXC, - Empty_Tuple, - GNWK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<1, - GNWC, - GKXC, - Empty_Tuple, - GNWK, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index 74d4a3829aa..10762494474 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -3,7 +3,6 @@ set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_scaleadd_ab_instance ${GROUPED_CONV3D_FWD_SCALEADD_AB}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp deleted file mode 100644 index 315aefb8251..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_ab_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - F32, - ScaleAdd, - ScaleAdd, - PassThrough, - TF32, - TF32>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt index ea9bbc3a4ab..1be1db7d1d9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt @@ -3,7 +3,6 @@ set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_scaleadd_scaleadd_relu_instance ${GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp deleted file mode 100644 index 35d86e0e9dd..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector, - NDHWGK, - F32, - F32, - ck::Tuple, - F32, - PassThrough, - PassThrough, - ScaleAddScaleAddRelu, - TF32, - TF32>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - ck::Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - ck::Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - ck::Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck From 040aee6b757227ffb139a91bcba7c398d8b4c836 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Wed, 24 Sep 2025 16:25:53 +0800 Subject: [PATCH 12/19] remove gnhwc/ngchw/ngcdhw instances --- .../gpu/grouped_convolution_forward.hpp | 150 +++++------------- .../grouped_convolution_forward_comp_xdl.inc | 30 ---- ...uped_convolution_forward_mem_inter_xdl.inc | 15 -- ...uped_convolution_forward_mem_intra_xdl.inc | 15 -- .../gpu/grouped_convolution_forward_xdl.inc | 46 ------ ..._convolution_forward_xdl_merged_groups.inc | 31 ---- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 7 - ...chw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp | 41 ----- ...dl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp | 66 -------- ...dl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 56 ------- ...dl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp | 41 ----- ...kcyx_ngkhw_f32_tf32_mem_inter_instance.cpp | 42 ----- ...kcyx_ngkhw_f32_tf32_mem_intra_instance.cpp | 42 ----- ...ps_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 50 ------ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 15 -- ...w_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp | 57 ------- ...czyx_ngkdhw_f32_tf32_mem_inter_instance.in | 67 -------- ...czyx_ngkdhw_f32_tf32_mem_intra_instance.in | 67 -------- ...ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp | 49 ------ 19 files changed, 41 insertions(+), 846 deletions(-) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 330d83c0793..e73e8aac1e9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -169,20 +169,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: AComputeType and BComputeType should be the same"); - if constexpr(is_same_v) - { - - add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances(op_ptrs); - } - else - { - - add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); - } + add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -210,39 +200,20 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: AComputeType and BComputeType should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( - op_ptrs); - } - else - { - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( - op_ptrs); - } + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -314,35 +285,18 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: AComputeType and BComputeType should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instances( - op_ptrs); - } - else - { - add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances( - op_ptrs); - } + add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -393,18 +347,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: AComputeType and BComputeType should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances(op_ptrs); - } - else - { - add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs); - } + add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -516,7 +462,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: AComputeType and BComputeType should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instances( - op_ptrs); - } - else - { - add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( - op_ptrs); - } + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc index d0dcdb7b84d..91221c2c0cf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -247,21 +247,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -432,21 +417,6 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instances( - std::vector>>& instances); #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc index a4fb152828e..ac7a773aff4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -135,21 +135,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instances PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc index 0d07d9da43d..68cbc56b41c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -135,21 +135,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instances PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index b0ba477dba9..7f3c9d9d365 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -120,21 +120,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances( - std::vector>>& instances); #endif // grouped conv2d forward, NHWGC/GKYXC/NHWGK @@ -307,21 +292,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -429,22 +399,6 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( PassThrough, PassThrough, PassThrough>>>& instances); - -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc index adf8aa72a65..eedbd1abd08 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc @@ -113,22 +113,6 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_insta PassThrough, PassThrough, PassThrough>>>& instances); - -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -273,21 +257,6 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_in PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( - std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 0e48c974cae..5987b90685a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -5,7 +5,6 @@ set(GROUPED_CONV2D_FWD xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp # NHWGC, GKYXC, NHWGK xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -19,11 +18,9 @@ set(GROUPED_CONV2D_FWD xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp - xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp # NGCHW, GKCYX, NGKHW xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp - xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp @@ -45,7 +42,6 @@ set(GROUPED_CONV2D_FWD xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instance.cpp - xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp #mem # NHWGC, GKYXC, NHWGK xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -61,12 +57,10 @@ set(GROUPED_CONV2D_FWD xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_intra_instance.cpp - xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp # NGCHW, GKCYX, NGKHW xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_mem_inter_instance.cpp - xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp #comp # NHWGC, GKYXC, NHWGK xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -81,7 +75,6 @@ set(GROUPED_CONV2D_FWD # NGCHW, GKCYX, NGKHW xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp - xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_part2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp deleted file mode 100644 index ad9ad654aab..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp deleted file mode 100644 index 9c8589c7b3f..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdDefault>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwd1x1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdOddC>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp deleted file mode 100644 index 6f921c24322..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp deleted file mode 100644 index 451d5823996..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_generic_instances<2, - NGCHW, - GKYXC, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp deleted file mode 100644 index 98e52ab15da..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_inter_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault, - Interwave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp deleted file mode 100644 index 5585de5b4a9..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_tf32_mem_intra_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault, - Intrawave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp deleted file mode 100644 index 8af95f920cd..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd3x3>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 5b1f681becd..5774db21c98 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -29,7 +29,6 @@ set(GROUPED_CONV3D_FWD xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp - xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp @@ -108,13 +107,6 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV3D_FWD OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances @@ -137,13 +129,6 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV3D_FWD OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) -generate_sharded_instantiations( - INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances - TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in - NUM_SHARDS 10 - SRC_LIST GROUPED_CONV3D_FWD - OUTPUT_DIR ${GENERATED_DIR}/xdl/mem -) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp deleted file mode 100644 index bb62769b3b4..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in deleted file mode 100644 index 00e39603e71..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instance.in +++ /dev/null @@ -1,67 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_inter_instances& instances) -{ - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>, - Shards, - ShardIndex>{}); - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>, - Shards, - ShardIndex>{}); - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in deleted file mode 100644 index 9e13fddd32e..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instance.in +++ /dev/null @@ -1,67 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/utility/filter_tuple.hpp" - -namespace ck::tensor_operation::device::instance { - -using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances = - std::vector>>; -template -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances_shard( - device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_mem_intra_instances& instances) -{ - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>, - Shards, - ShardIndex>{}); - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>, - Shards, - ShardIndex>{}); - add_device_operation_instances( - instances, - ck::util::filter_tuple_by_modulo_t< - device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>, - Shards, - ShardIndex>{}); -} - -} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp deleted file mode 100644 index 753d452990f..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd3x3>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck From a8d9fbe193b1a645248b29cb1204e58eeeb078c4 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 25 Sep 2025 11:47:39 +0800 Subject: [PATCH 13/19] remove useless bwd instances --- .../grouped_convolution_backward_weight.hpp | 41 +------ ...rouped_convolution_backward_weight_xdl.inc | 111 ------------------ .../grouped_conv1d_bwd_weight/CMakeLists.txt | 1 - ...t_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp | 47 -------- .../grouped_conv2d_bwd_data/CMakeLists.txt | 5 - ...dl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp | 51 -------- ...hw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp | 42 ------- ...dl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 42 ------- ...dl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp | 42 ------- .../grouped_conv2d_bwd_weight/CMakeLists.txt | 5 - ...gnhwk_f32_tf32_default_pipev1_instance.cpp | 42 ------- ...dl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp | 50 -------- ...xc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp | 42 ------- ...dl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp | 51 -------- ...dl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp | 41 ------- .../grouped_conv3d_bwd_weight/CMakeLists.txt | 3 - ...gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp | 49 -------- ...ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp | 51 -------- ...ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp | 41 ------- 19 files changed, 2 insertions(+), 755 deletions(-) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 6d20b39ad27..5e62793dfb5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -278,12 +278,7 @@ struct DeviceOperationInstanceFactory, "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( - op_ptrs); - } - else + if constexpr(is_same_v) { add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( op_ptrs); @@ -321,14 +316,7 @@ struct DeviceOperationInstanceFactory, "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instances( - op_ptrs); - } - else + if constexpr(is_same_v) { add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( op_ptrs); @@ -561,11 +549,6 @@ struct DeviceOperationInstanceFactory) - { - add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( - op_ptrs); - } } #endif } @@ -603,11 +586,6 @@ struct DeviceOperationInstanceFactory) - { - add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( - op_ptrs); - } } #endif } @@ -628,11 +606,6 @@ struct DeviceOperationInstanceFactory) - { - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instances( - op_ptrs); - } } #endif #ifdef CK_ENABLE_FP16 @@ -860,11 +833,6 @@ struct DeviceOperationInstanceFactory) - { - add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( - op_ptrs); - } } #endif } @@ -902,11 +870,6 @@ struct DeviceOperationInstanceFactory) - { - add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instances( - op_ptrs); - } } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc index 7086f7034cf..a8dffa4dda4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -47,19 +47,6 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_tf32_instances( - std::vector>>& instances); #endif // conv2d backward weight #ifdef CK_ENABLE_BF16 @@ -137,20 +124,6 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipe PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -637,20 +596,6 @@ void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( std::vector>>& instances); -void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_default_pipev2_instances( std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -1270,20 +1187,6 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<1, - GNWC, - GKXC, - GNWK, - ConvBwdWeightDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances< - 1, - GNWC, - GKXC, - GNWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 6b8df51e491..4cc2fdeab7f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -4,7 +4,6 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -16,19 +15,15 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_16_16_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp deleted file mode 100644 index 4bfd07f60d0..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< - 2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp deleted file mode 100644 index 5fe6268a091..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_16_16_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_16_16_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<2, - NGKHW, - GKCYX, - Empty_Tuple, - NGCHW, - ConvBwdDataDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp deleted file mode 100644 index cc103cd4f13..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_instances<2, - NGKHW, - GKCYX, - Empty_Tuple, - NGCHW, - ConvBwdDataDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp deleted file mode 100644 index 6af5fc7fbba..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_generic_instances<2, - NGKHW, - GKYXC, - Empty_Tuple, - NGCHW, - ConvBwdDataDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 5eb7650746c..2fd7174a7f1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -2,14 +2,11 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instance.cpp - xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instance.cpp - xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -47,7 +44,6 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp - xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instance.cpp xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev5_instance.cpp @@ -59,7 +55,6 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp - xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp deleted file mode 100644 index 404fe4f5225..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_default_pipev1_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< - 2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightDefault, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v1>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp deleted file mode 100644 index 569edd62ff6..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instance.cpp +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< - 2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< - 2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp deleted file mode 100644 index 91bee86045c..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_tf32_pad0_pipev1_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< - 2, - GNHWC, - GKYXC, - GNHWK, - ConvBwdWeightFilter1x1Stride1Pad0, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v1>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp deleted file mode 100644 index 9d56ee5c03a..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<2, - NGCHW, - GKCYX, - NGKHW, - ConvBwdWeightDefault, - 1, - 1>{}); - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<2, - NGCHW, - GKCYX, - NGKHW, - ConvBwdWeightDefault, - 4, - 4>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp deleted file mode 100644 index f39bbc7120d..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< - 2, - NGCHW, - GKYXC, - NGKHW, - ConvBwdWeightDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 301641ffde3..f9922b1f375 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -2,7 +2,6 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -40,13 +39,11 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instance.cpp xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp - xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp - xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev5_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp deleted file mode 100644 index 88746535329..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< - 3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< - 3, - GNDHWC, - GKZYXC, - GNDHWK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp deleted file mode 100644 index 43719e9339e..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3, - NGCDHW, - GKCZYX, - NGKDHW, - ConvBwdWeightDefault, - 1, - 1>{}); - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3, - NGCDHW, - GKCZYX, - NGKDHW, - ConvBwdWeightDefault, - 4, - 4>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp deleted file mode 100644 index a819c3fe996..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_generic_instances< - 3, - NGCDHW, - GKZYXC, - NGKDHW, - ConvBwdWeightDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck From 94da54bbd39e18bb12ca558a4477d680852ee14e Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 25 Sep 2025 12:02:03 +0800 Subject: [PATCH 14/19] change check_err for tf32 --- .../convnd_bwd_data_common.hpp | 5 ++- include/ck/library/utility/check_err.hpp | 37 ++++++++++++++----- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index aead9734901..6f8230dc635 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -174,7 +174,10 @@ int run_conv_bwd_data(bool do_verification, in_device_buf.FromDevice(in_device.mData.data()); - return ck::utils::check_err(in_device, in_host) ? 0 : 1; + return ck::utils::check_err, Tensor, ComputeDataType>( + in_device, in_host) + ? 0 + : 1; } return 0; diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 84a166f1c5c..3637053e14b 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -157,10 +157,13 @@ double get_absolute_threshold(const double max_possible_num, const int number_of return std::max(acc_error, midway_error); } -template +template > typename std::enable_if< std::is_same_v, ranges::range_value_t> && - std::is_same_v, float>, + std::is_same_v, float> && + std::is_same_v, bool>::type check_err(const Range& out, const RefRange& ref, @@ -207,12 +210,14 @@ check_err(const Range& out, return res; } -template +template > typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_floating_point_v> && !std::is_same_v, half_t> && - !std::is_same_v, float>, + !std::is_same_v, bool>::type check_err(const Range& out, const RefRange& ref, @@ -259,7 +264,9 @@ check_err(const Range& out, return res; } -template +template > typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_same_v, bhalf_t>, @@ -310,7 +317,9 @@ check_err(const Range& out, return res; } -template +template > typename std::enable_if< std::is_same_v, ranges::range_value_t> && std::is_same_v, half_t>, @@ -360,7 +369,9 @@ check_err(const Range& out, return res; } -template +template > std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_integral_v> && !std::is_same_v, bhalf_t> && @@ -417,7 +428,9 @@ check_err(const Range& out, return res; } -template +template > std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, f8_t>), bool> @@ -466,7 +479,9 @@ check_err(const Range& out, return res; } -template +template > std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, bf8_t>), bool> @@ -511,7 +526,9 @@ check_err(const Range& out, return res; } -template +template > std::enable_if_t<(std::is_same_v, ranges::range_value_t> && std::is_same_v, f4_t>), bool> From f54bab1ae3467c683ce1a890efb397ca08bffb02 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 25 Sep 2025 12:27:30 +0800 Subject: [PATCH 15/19] fix clang-format fail --- profiler/src/profile_grouped_conv_bwd_data.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 62482fc35a7..95098e23011 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -107,9 +107,9 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) using WeiLayout = decltype(wei_layout); using InLayout = decltype(in_layout); - using OutDataType = decltype(out_type); - using WeiDataType = decltype(wei_type); - using InDataType = decltype(in_type); + using OutDataType = decltype(out_type); + using WeiDataType = decltype(wei_type); + using InDataType = decltype(in_type); using ComputeDataType = decltype(compute_type); bool pass = ck::profiler::profile_grouped_conv_bwd_data_impl Date: Fri, 26 Sep 2025 11:24:41 +0800 Subject: [PATCH 16/19] remove non-ndhwgc/nhwgc/nhwc instances --- example/17_convnd_bwd_data/CMakeLists.txt | 17 -- .../convnd_bwd_data_common.hpp | 14 +- .../convnd_bwd_data_xdl_fp32.cpp | 207 ----------------- .../convnd_bwd_data_xdl_fp32_tf32.cpp | 212 ------------------ .../gpu/device/device_conv_fwd.hpp | 3 +- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 17 +- ...device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp | 48 +--- ...d_conv_bwd_data_transpose_xdl_instance.hpp | 36 --- ...ice_grouped_conv_bwd_data_xdl_instance.hpp | 4 +- ...rouped_conv_bwd_weight_v3_xdl_instance.hpp | 2 +- ..._conv_bwd_weight_xdl_bilinear_instance.hpp | 80 +++---- ...e_grouped_conv_bwd_weight_xdl_instance.hpp | 100 ++++----- ...ped_conv_bwd_weight_xdl_scale_instance.hpp | 78 +++---- .../gpu/grouped_convolution_backward_data.hpp | 60 ++--- .../grouped_convolution_backward_data_xdl.inc | 78 ------- .../grouped_convolution_backward_weight.hpp | 80 ++----- ...rouped_convolution_backward_weight_xdl.inc | 68 ------ .../gpu/conv2d_fwd/CMakeLists.txt | 1 - ...d_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp | 136 ----------- ..._ngkhw_f32_tf32_vec_transpose_instance.cpp | 42 ---- ...gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp | 52 ----- ..._gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp | 51 ----- ..._gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp | 42 ---- ...ngkdhw_f32_tf32_vec_transpose_instance.cpp | 42 ---- ...ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp | 42 ---- ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 51 ----- .../grouped_conv3d_bwd_weight/CMakeLists.txt | 5 - ...dhwgk_f32_tf32_default_pipev2_instance.cpp | 42 ---- ...dhwgk_f32_tf32_default_pipev5_instance.cpp | 42 ---- ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 48 ---- ...c_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp | 42 ---- ...c_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp | 42 ---- .../src/profile_grouped_conv_bwd_weight.cpp | 12 +- .../src/profile_grouped_conv_fwd_clamp.cpp | 2 +- 34 files changed, 188 insertions(+), 1610 deletions(-) delete mode 100644 example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp delete mode 100644 example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 70228d08938..39f9fb8ec06 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -3,23 +3,6 @@ if(result EQUAL 0) target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) endif() -add_example_executable(example_convnd_bwd_data_xdl_fp32 convnd_bwd_data_xdl_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(example_convnd_bwd_data_xdl_fp32 PRIVATE utility) -endif() - -list(APPEND gpu_list gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_bwd_data_xdl_fp32_tf32 convnd_bwd_data_xdl_fp32_tf32.cpp) - if(result EQUAL 0) - target_link_libraries(example_convnd_bwd_data_xdl_fp32_tf32 PRIVATE utility) - endif() - set(target 1) - endif() -endforeach() - add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) if(result EQUAL 0) target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index 6f8230dc635..d219df02453 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -33,8 +33,7 @@ template + typename DeviceConvNdBwdDataInstance> int run_conv_bwd_data(bool do_verification, int init_method, bool time_kernel, @@ -151,11 +150,7 @@ int run_conv_bwd_data(bool do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp, - 0, - 0, - 0, - ComputeDataType>(); + OutElementOp>(); auto ref_invoker = ref_conv.MakeInvoker(); @@ -174,10 +169,7 @@ int run_conv_bwd_data(bool do_verification, in_device_buf.FromDevice(in_device.mData.data()); - return ck::utils::check_err, Tensor, ComputeDataType>( - in_device, in_host) - ? 0 - : 1; + return ck::utils::check_err(in_device, in_host) ? 0 : 1; } return 0; diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp deleted file mode 100644 index c4037842a3a..00000000000 --- a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32.cpp +++ /dev/null @@ -1,207 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_bwd_data_common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" - -using InDataType = float; -using WeiDataType = float; -using OutDataType = float; -using AccDataType = float; - -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdDefault = - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; - -template -using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl< - NDimSpatial, // NDimSpatial - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvBwdDefault, // ConvolutionBackwardDataSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder - S<0, 2, 1>, // BBlockTransferSrcAccessOrder - 1, // BBlockTransferSrcVectorDim - 2, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 7, - 1>; // GemmCThreadTransferDstScalarPerVector - -int main(int argc, char* argv[]) -{ - namespace ctc = ck::tensor_layout::convolution; - - print_helper_msg(); - - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - ck::utils::conv::ConvParam conv_param{ - 2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; - - if(argc == 1) - { - // use default - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - const ck::index_t num_dim_spatial = std::stoi(argv[4]); - - conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); - } - - const auto in_element_op = InElementOp{}; - const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{}; - - if(conv_param.num_dim_spatial_ == 1) - { - using InLayout = ctc::GNWC; - using WeiLayout = ctc::GKXC; - using OutLayout = ctc::GNWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_conv_bwd_data<1, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceConvNdBwdDataInstance<1>>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); - } - else if(conv_param.num_dim_spatial_ == 2) - { - using InLayout = ctc::GNHWC; - using WeiLayout = ctc::GKYXC; - using OutLayout = ctc::GNHWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_conv_bwd_data<2, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceConvNdBwdDataInstance<2>>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); - } - else if(conv_param.num_dim_spatial_ == 3) - { - using InLayout = ctc::GNDHWC; - using WeiLayout = ctc::GKZYXC; - using OutLayout = ctc::GNDHWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_conv_bwd_data<3, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceConvNdBwdDataInstance<3>>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); - } - - return 0; -} diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp deleted file mode 100644 index b4a0a2273a9..00000000000 --- a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp32_tf32.cpp +++ /dev/null @@ -1,212 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_bwd_data_common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp" - -using InDataType = float; -using WeiDataType = float; -using OutDataType = float; -using AccDataType = float; -using ComputeDataType = ck::tf32_t; - -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvBwdDefault = - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; - -template -using DeviceConvNdBwdDataInstance = ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl< - NDimSpatial, // NDimSpatial - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvBwdDefault, // ConvolutionBackwardDataSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder - S<0, 2, 1>, // BBlockTransferSrcAccessOrder - 1, // BBlockTransferSrcVectorDim - 2, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 7, // CThreadTransferSrcDstVectorDim - 1, // GemmCThreadTransferDstScalarPerVector - ComputeDataType>; - -int main(int argc, char* argv[]) -{ - namespace ctc = ck::tensor_layout::convolution; - - print_helper_msg(); - - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - ck::utils::conv::ConvParam conv_param{ - 2, 1, 128, 256, 256, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}; - - if(argc == 1) - { - // use default - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - const ck::index_t num_dim_spatial = std::stoi(argv[4]); - - conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); - } - - const auto in_element_op = InElementOp{}; - const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{}; - - if(conv_param.num_dim_spatial_ == 1) - { - using InLayout = ctc::GNWC; - using WeiLayout = ctc::GKXC; - using OutLayout = ctc::GNWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_conv_bwd_data<1, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceConvNdBwdDataInstance<1>, - ComputeDataType>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); - } - else if(conv_param.num_dim_spatial_ == 2) - { - using InLayout = ctc::GNHWC; - using WeiLayout = ctc::GKYXC; - using OutLayout = ctc::GNHWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_conv_bwd_data<2, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceConvNdBwdDataInstance<2>, - ComputeDataType>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); - } - else if(conv_param.num_dim_spatial_ == 3) - { - using InLayout = ctc::GNDHWC; - using WeiLayout = ctc::GKZYXC; - using OutLayout = ctc::GNDHWK; - - const auto in_g_n_c_wis_desc = - ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( - conv_param); - - const auto wei_g_k_c_xs_desc = - ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( - conv_param); - - const auto out_g_n_k_wos_desc = - ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( - conv_param); - - return run_conv_bwd_data<3, - InDataType, - WeiDataType, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - DeviceConvNdBwdDataInstance<3>, - ComputeDataType>(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); - } - - return 0; -} diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp index 9859b6d5854..4dc11dbefd7 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -20,8 +20,7 @@ template + typename OutElementwiseOperation> struct DeviceConvFwd : public BaseOperator { virtual std::unique_ptr diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 16a53c24b2b..cecfa48408f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -54,8 +54,7 @@ template + ck::index_t CThreadTransferDstScalarPerVector> struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K : public DeviceConvFwd<2, ck::tensor_layout::convolution::NHWC, @@ -66,8 +65,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K OutDataType, InElementwiseOperation, WeiElementwiseOperation, - OutElementwiseOperation, - ComputeDataType> + OutElementwiseOperation> { using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; @@ -80,8 +78,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CDataType = OutDataType; // TODO make A/B datatype different - using ABDataTypeElementwise = ADataType; // for load/store and elementwise operation - using ABDataTypeGemm = ComputeDataType; // only for gemm computation + using ABDataType = InDataType; static constexpr index_t NDimSpatial = 2; @@ -334,7 +331,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, - ABDataTypeGemm, // TODO: distinguish A/B datatype + ABDataType, // TODO: distinguish A/B datatype AccDataType, CDataType, InMemoryDataOperationEnum::Set, @@ -475,7 +472,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { const auto kernel = kernel_gemm_xdlops_v2r3()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 403a7be9689..d0743421272 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -55,8 +55,7 @@ template + ck::index_t CThreadTransferDstScalarPerVector> struct DeviceConvNdBwdDataNwcKxcNwk_Xdl : public DeviceConvBwdData< NDimSpatial, @@ -79,14 +78,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl WeiElementwiseOperation, OutElementwiseOperation> { - - DeviceConvNdBwdDataNwcKxcNwk_Xdl() - { - static_assert(is_same_v || - (is_same_v && is_same_v), - "InDataType and ComputeDataType need to be the same or (InDataType=float and " - "ComputeDataType=tf32_t)"); - } using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Xdl; GET_NXDL_PER_WAVE_IMPL @@ -98,7 +89,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl using CDataType = InDataType; // TODO make A/B datatype different - using ABDataType = ComputeDataType; + using ABDataType = InDataType; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -1204,36 +1195,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl } } - void Print() const - { - std::cout << "InDataType: " << get_type_name() - << "; WeiDataType: " << get_type_name() - << "; OutDataType: " << get_type_name() - << "; AccDataType: " << get_type_name() << std::endl; - auto print_v = [](std::ostream& os, - const std::vector& v, - const std::string& name) -> std::ostream& { - os << name << ": ["; - for(size_t i = 0; i < v.size(); ++i) - { - os << v[i]; - if(i + 1 < v.size()) - os << ", "; - } - os << "]"; - return os; - }; - std::cout << "Conv params: Ndims: " << NDimSpatial << ", N: " << Conv_N_ - << ", K: " << Conv_K_ << ", C: " << Conv_C_ << "\n\t"; - print_v(std::cout, input_spatial_lengths_, "input_spatial_lengths") << "\n\t"; - print_v(std::cout, filter_spatial_lengths_, "filter_spatial_lengths") << "\n\t"; - print_v(std::cout, output_spatial_lengths_, "output_spatial_lengths") << "\n\t"; - print_v(std::cout, conv_filter_strides_, "conv_filter_strides") << "\n\t"; - print_v(std::cout, conv_filter_dilations_, "conv_filter_dilations") << "\n\t"; - print_v(std::cout, input_left_pads_, "input_left_pads") << "\n\t"; - print_v(std::cout, input_right_pads_, "input_right_pads") << std::endl; - } - const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; @@ -1265,11 +1226,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl template float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - - if(stream_config.log_level_ > 0) - { - arg.Print(); - } float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp index 48b76ed8d83..04165382f4a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp @@ -139,42 +139,6 @@ using device_grouped_conv_bwd_data_transpose_xdl_f32_instances = // clang-format on >; -// f32_f32_f32_f32 tf32 -template -using device_grouped_conv_bwd_data_transpose_xdl_f32_tf32_instances = - std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopSched| AComputeType| BComputeType| MaxTranspose| MaxTranspose| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | | | TransferIn| TransferOut| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | ScalarPer| ScalarPer| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Vector| Vector| - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 2>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 4, 4>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 1, 2>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), TF32, TF32, 2, 1> - // clang-format on - >; - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 064666c8641..cd84b0e8313 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -128,13 +128,13 @@ using device_grouped_conv_bwd_data_xdl_f16_nchw_instances = // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp index 114a6cab35d..4e096e5b449 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -95,7 +95,7 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp index a81ec510819..362195a819a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp @@ -120,26 +120,26 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_bilinear_instances = std: //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -156,23 +156,23 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_bilinear_instances = std //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, // instance for small conv.K - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; @@ -190,25 +190,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_bilinear_inst //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> #endif // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp index 10f4f3b69a0..095b847a0c6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp @@ -183,25 +183,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances = std::tuple< //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, F16, F16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -236,23 +236,23 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances = std //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -271,25 +271,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances = std::tuple< //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| | | ScalarPerVector| ScalarPerVector| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for bf16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF16, BF16, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> // clang-format on >; @@ -309,25 +309,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_instances = s //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 // generic instance - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, - DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector>, DeviceGroupedConvBwdWeight_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8, MaxTransposeTransferSrcScalarPerVector, MaxTransposeTransferDstScalarPerVector> #endif // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp index dc365c4fdcf..9305076e97b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp @@ -120,25 +120,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_scale_instances = std::tu //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -155,23 +155,23 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_scale_instances = std::t //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, // instance for small conv.K - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; @@ -189,25 +189,25 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_comp_bf8_f8_scale_instanc //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | #if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 // generic instance - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BF8, F8>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 32, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 8, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 4, 4>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 2, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> #endif // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 04608e1996e..617ddad2807 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -84,20 +84,10 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: this operator requires the same compute type"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_tf32_instances( - op_ptrs); - } - else - { - add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( - op_ptrs); - } + add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -170,20 +160,10 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: this operator requires the same compute type"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_tf32_instances( - op_ptrs); - } - else - { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances( - op_ptrs); - } + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -213,28 +193,14 @@ struct DeviceOperationInstanceFactory< #endif #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: this operator requires the same compute type"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_vec_transpose_instances( - op_ptrs); - } - else - { - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( - op_ptrs); - } + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index be2cbff9e4c..d73bca42116 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -38,21 +38,6 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_tf32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( @@ -219,21 +204,6 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_tf32_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances( @@ -336,54 +306,6 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_ PassThrough, PassThrough, PassThrough>>>& instances); - -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_16_16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_vec_transpose_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 5e62793dfb5..bca19b05a40 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -274,15 +274,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( - op_ptrs); - } + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -312,20 +307,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( - op_ptrs); - } + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_default_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_pad0_pipev1_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -540,15 +531,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( - op_ptrs); - } + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkcyx_ngkhw_f32_instances( + op_ptrs); } #endif } @@ -577,15 +564,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( - op_ptrs); - } + add_device_grouped_conv2d_bwd_weight_xdl_ngchw_gkyxc_ngkhw_f32_instances( + op_ptrs); } #endif } @@ -597,15 +580,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - op_ptrs); - } + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -651,19 +630,6 @@ struct DeviceOperationInstanceFactory) - { - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( - op_ptrs); - } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc index a8dffa4dda4..61744531f57 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -1211,20 +1211,6 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( - std::vector>>& instances); void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( - std::vector>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 028d5b518f1..04b313d075b 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,7 +1,6 @@ # ONLY XDL_KERNELS set(DEVICE_CONV2D_FWD_INSTANCES) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp deleted file mode 100644 index ffcdf57d220..00000000000 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instance.cpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#ifdef CK_ENABLE_FP32 -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F32 = float; -using TF32 = ck::tf32_t; - -template -using S = ck::Sequence; - -using NHWC = ck::tensor_layout::convolution::NHWC; -using KYXC = ck::tensor_layout::convolution::KYXC; -using NHWK = ck::tensor_layout::convolution::NHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32> - // clang-format on - >; - -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_tf32_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32> - // clang-format on - >; - -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_tf32_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1, TF32> - // clang-format on - >; - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instances( - std::vector>>& instances) -{ -#if CK_BUILD_DEPRECATED -#pragma message "These instances are getting deprecated" - add_device_operation_instances(instances, - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_tf32_instances{}); - add_device_operation_instances( - instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_tf32_instances{}); - add_device_operation_instances( - instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_tf32_instances{}); -#else -#pragma message "These instances were deprecated" - std::ignore = instances; -#endif -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp deleted file mode 100644 index 93f217ca054..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_tf32_vec_transpose_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_tf32_vec_transpose_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_transpose_xdl_f32_tf32_instances<2, - NGKHW, - GKCYX, - Empty_Tuple, - NGCHW, - ConvBwdDataDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp deleted file mode 100644 index c9223e42ea4..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_tf32_instance.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// wo, k] -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< - 3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp deleted file mode 100644 index 63e90333a96..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3, - NDHWGK, - GKZYXC, - Empty_Tuple, - NDHWGC, - ConvBwdDataDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< - 3, - NDHWGK, - GKZYXC, - Empty_Tuple, - NDHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp deleted file mode 100644 index cea4aac2ff2..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_16_16_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_tf32_16_16_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3, - NGKDHW, - GKCZYX, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp deleted file mode 100644 index 4b12fece2d1..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_tf32_vec_transpose_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_tf32_vec_transpose_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_transpose_xdl_f32_tf32_instances<3, - NGKDHW, - GKCZYX, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp deleted file mode 100644 index 39bcb567be0..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_tf32_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_tf32_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_tf32_generic_instances<3, - NGKDHW, - GKZYXC, - Empty_Tuple, - NGCDHW, - ConvBwdDataDefault>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp deleted file mode 100644 index 12b36b77ca0..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( - std::vector, - NDHWGC, - F32, - F32, - Tuple, - F32, - PassThrough, - PassThrough, - Bilinear, - TF32, - TF32>>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances<3, - NDHWGK, - GKZYXC, - Tuple, - NDHWGC, - ConvBwdDataDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_bilinear_f32_tf32_instances< - 3, - NDHWGK, - GKZYXC, - Tuple, - NDHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index f9922b1f375..5574cf82f9f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -6,7 +6,6 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp @@ -18,13 +17,9 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp - xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp - xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp - xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp - xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp deleted file mode 100644 index dab91ec7475..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v2>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp deleted file mode 100644 index 01229234ff2..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v5>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp deleted file mode 100644 index ac6c3b60e40..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp +++ /dev/null @@ -1,48 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - std::vector>>& instances) -{ - // 1. Default - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault>{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances(instances, - device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp deleted file mode 100644 index c479cc20481..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v2>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp deleted file mode 100644 index cfb0e8a65e5..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0, - BlockGemmPipelineScheduler::Intrawave, - BlockGemmPipelineVersion::v5>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index 24c848e08b7..1dd40fdbfe1 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -160,7 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } - if(data_type == ConvDataType::F32_F32_F32_TF32) + else if(data_type == ConvDataType::F32_F32_F32_TF32) { #if defined(__gfx942__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); @@ -182,7 +182,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) // fp32 atomic add is used for weight tensor in bf16 kernel return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); } - if(data_type == ConvDataType::F32_F32_F32_TF32) + else if(data_type == ConvDataType::F32_F32_F32_TF32) { #if defined(__gfx942__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); @@ -208,7 +208,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } - if(data_type == ConvDataType::F32_F32_F32_TF32) + else if(data_type == ConvDataType::F32_F32_F32_TF32) { #if defined(__gfx942__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); @@ -241,7 +241,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); } - if(data_type == ConvDataType::F32_F32_F32_TF32) + else if(data_type == ConvDataType::F32_F32_F32_TF32) { #if defined(__gfx942__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); @@ -304,7 +304,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) return profile( I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); } - if(data_type == ConvDataType::F32_F32_F32_TF32) + else if(data_type == ConvDataType::F32_F32_F32_TF32) { #if defined(__gfx942__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); @@ -338,7 +338,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); } - if(data_type == ConvDataType::F32_F32_F32_TF32) + else if(data_type == ConvDataType::F32_F32_F32_TF32) { #if defined(__gfx942__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp index f23e2ddc110..b3552a1a3b1 100644 --- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -198,7 +198,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[]) else if(data_type == ConvDataType::F32_F32_F32_TF32) { #if defined(__gfx942__) - return profile(I3, NDHWGC{}, NDHWGC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } } From 6f6657172e822222a2ec3136e95fa07bbafd599d Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Fri, 26 Sep 2025 12:06:26 +0800 Subject: [PATCH 17/19] complement ndhwgc instances --- .../gpu/grouped_convolution_backward_data.hpp | 20 +++- .../grouped_convolution_backward_data_xdl.inc | 31 ++++++ .../grouped_convolution_backward_weight.hpp | 38 ++++---- ...uped_convolution_backward_weight_scale.hpp | 5 +- ...rouped_convolution_backward_weight_xdl.inc | 95 ++++++++++++++++--- ...rouped_convolution_forward_scaleadd_ab.hpp | 2 +- .../gpu/grouped_convolution_forward_xdl.inc | 4 +- .../grouped_conv2d_bwd_weight/CMakeLists.txt | 2 +- .../grouped_conv3d_bwd_data/CMakeLists.txt | 2 + ..._gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp | 51 ++++++++++ ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 51 ++++++++++ .../grouped_conv3d_bwd_weight/CMakeLists.txt | 5 + ...dhwgk_f32_tf32_default_pipev2_instance.cpp | 42 ++++++++ ...dhwgk_f32_tf32_default_pipev5_instance.cpp | 42 ++++++++ ...ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 48 ++++++++++ ...c_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp | 42 ++++++++ ...c_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp | 42 ++++++++ .../src/profile_grouped_conv_bwd_weight.cpp | 2 +- .../src/profile_grouped_conv_fwd_clamp.cpp | 2 +- 19 files changed, 484 insertions(+), 42 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_16_16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 617ddad2807..351746488d9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -279,10 +279,22 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( - op_ptrs); + static_assert(is_same_v, + "Error: this operator requires the same compute type"); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_BF16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index d73bca42116..fdf4bc18236 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -454,6 +454,37 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_insta PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( + std::vector>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index bca19b05a40..7f91d0bee93 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -616,7 +616,6 @@ struct DeviceOperationInstanceFactory, "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -630,6 +629,19 @@ struct DeviceOperationInstanceFactory) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 @@ -790,15 +802,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( - op_ptrs); - } + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + op_ptrs); } #endif } @@ -827,15 +835,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - static_assert(is_same_v, - "Error: ComputeTypeA and ComputeTypeB should be the same"); - if constexpr(is_same_v) - { - add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances( - op_ptrs); - } + add_device_grouped_conv3d_bwd_weight_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instances( + op_ptrs); } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp index 4bb44b62e4b..46ddba312af 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp @@ -156,8 +156,9 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - if constexpr(is_same_v && - is_same_v) + static_assert(is_same_v, + "Error: this operator requires the same compute type"); + if constexpr(is_same_v) { add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc index 61744531f57..0d3159210df 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -1259,21 +1259,90 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipe PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( + std::vector>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index d7a217f1b84..1bea403afa2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -135,7 +135,7 @@ struct DeviceOperationInstanceFactory> && is_same_v> && - is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index 7f3c9d9d365..a59fcd9d6ea 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -639,7 +639,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( BF8>>>& instances); #endif -#if (defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( std::vector>>& instances); #endif -#if (defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_16_16_instances< + 3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..1db6494479e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_tf32_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 5574cf82f9f..f9922b1f375 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -6,6 +6,7 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev2_instance.cpp @@ -17,9 +18,13 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_default_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp new file mode 100644 index 00000000000..dab91ec7475 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp new file mode 100644 index 00000000000..01229234ff2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_default_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 00000000000..ac6c3b60e40 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp new file mode 100644 index 00000000000..c479cc20481 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp new file mode 100644 index 00000000000..cfb0e8a65e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instance.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index 1dd40fdbfe1..7d3f1ad6c09 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -42,7 +42,7 @@ static void print_helper_msg() << " 2: Input bf16, Weight fp32, Output bf16\n" << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n" << " 4: Input int8, Weight int8, Output int8\n" - << " 5: Input bf16, Weight bf16, Output bf16)\n" + << " 5: Input bf16, Weight bf16, Output bf16\n" << " 6: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " "N, K, Ho, Wo]\n" diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp index b3552a1a3b1..1b100ff8671 100644 --- a/profiler/src/profile_grouped_conv_fwd_clamp.cpp +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -52,7 +52,7 @@ static void print_helper_msg() << " 4: Input fp8, Weight fp8, Output fp8\n" << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" - << " 7: Input bf8, Weight fp8, Output fp8)\n" + << " 7: Input bf8, Weight fp8, Output fp8\n" << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" From 1be39fc22617e9a0bb6b97b3b86c7747218c78b3 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Fri, 26 Sep 2025 14:01:56 +0800 Subject: [PATCH 18/19] update copyright datetime --- ...ight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 2 +- ..._weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp index 522598ba87d..a71c02aec15 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp index 37692f3478a..65f79141b37 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_scale_instance.hpp" From a1b65ecbc1f5bfeed9de6e95cd705d0e7740660e Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Mon, 29 Sep 2025 17:21:15 +0800 Subject: [PATCH 19/19] add check in IsSupportedArgument() --- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 16 ++++++++++++++++ ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 16 ++++++++++++++++ ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 17 +++++++++++++++++ ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 16 ++++++++++++++++ ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 17 +++++++++++++++++ ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 19 +++++++++++++++++++ ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 17 +++++++++++++++++ 7 files changed, 118 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 57ea476ced8..ab6679236d2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1498,6 +1498,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { return false; } + if constexpr(is_same_v || is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } if constexpr(!IsSplitKSupported) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 987a1e273ae..ab185700b6d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -951,6 +951,22 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { return false; } + if constexpr(is_same_v || is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } if constexpr(NDimSpatial == 1) { if constexpr(!is_GNWC_GKXC_GNWK()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index e38768b2fa1..50796f78b40 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -1687,6 +1687,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const index_t GemmK = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + if constexpr(is_same_v || is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } + if(get_warp_size() == 64) { if constexpr(NXdlPerWave64 > 0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 22fc13bae40..c7ee3e9ecfb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -950,6 +950,22 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { return false; } + if constexpr(is_same_v || is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } if constexpr(NDimSpatial == 1) { if constexpr(!is_GNWC_GKXC_GNWK()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 735eebbdf66..07722155fda 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1289,6 +1289,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + if constexpr(is_same_v || is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } + if(get_warp_size() == 64) { if constexpr(NXdlPerWave64 > 0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index dd2e429a01d..dbc60e3fdc1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -1399,6 +1399,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 } return false; } + + if constexpr(is_same_v || + is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } + // check ConvolutionForwardSpecialization if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 25afe466907..020b3dc5a63 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -820,6 +820,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor { return false; } + if constexpr(is_same_v || + is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } // check ConvolutionForwardSpecialization if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)