Skip to content

Commit a0e48cb

Browse files
authored
Merge branch 'develop' into LWPCK-3688-cktile-contraction-multi-D
2 parents 4f89943 + 32773fe commit a0e48cb

File tree

333 files changed

+8295
-2411
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

333 files changed

+8295
-2411
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
66

77
### Added
88
* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM.
9+
* Added the new api to load different memory sizes to SGPR.
910
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
1011
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
1112
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data

Dockerfile

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,23 @@
1+
12
FROM ubuntu:24.04
23
ARG DEBIAN_FRONTEND=noninteractive
3-
ARG ROCMVERSION=6.4.1
4+
ARG ROCMVERSION=7.0.1
45
ARG compiler_version=""
56
ARG compiler_commit=""
67
ARG CK_SCCACHE=""
78
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
89
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
10+
ENV DEBIAN_FRONTEND=noninteractive
911

1012
# Add rocm repository
1113
RUN set -xe && \
12-
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \
13-
curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
14-
15-
RUN if [ "$ROCMVERSION" != "6.5" ]; then \
16-
sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \
17-
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \
18-
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
19-
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \
20-
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \
21-
fi
14+
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
2215

23-
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \
24-
amdgpu-install -y --usecase=rocm --no-dkms
16+
RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \
17+
apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \
18+
apt update && \
19+
apt install python3-setuptools python3-wheel -y && \
20+
apt install rocm-dev -y
2521

2622
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
2723
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
@@ -45,7 +41,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
4541
libelf-dev \
4642
libnuma-dev \
4743
libpthread-stubs0-dev \
48-
llvm-amdgpu \
4944
mpich \
5045
net-tools \
5146
pkg-config \
@@ -61,17 +56,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
6156
zip \
6257
libzstd-dev \
6358
openssh-server \
64-
clang-format-12 \
6559
clang-format-18 \
6660
kmod && \
6761
apt-get clean && \
6862
rm -rf /var/lib/apt/lists/* && \
6963
rm -rf amdgpu-install* && \
70-
# Remove unnecessary rocm components that take a lot of space
71-
apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt
72-
7364
#Install latest ccache
74-
RUN git clone https://github.com/ccache/ccache.git && \
65+
git clone https://github.com/ccache/ccache.git && \
7566
cd ccache && mkdir build && cd build && cmake .. && make install && \
7667
#Install ninja build tracing tools
7768
cd / && \

Dockerfile.compiler

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1"
1+
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1"
22
FROM $BASE_DOCKER
33
ARG compiler_version=""
44
ARG compiler_commit=""

Jenkinsfile

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def getBaseDockerImageName(){
5353
}
5454
else{
5555
def ROCM_numeric = parseVersion("${params.ROCMVERSION}")
56-
if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){
56+
if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){
5757
img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}"
5858
}
5959
else{
@@ -476,7 +476,7 @@ def buildHipClangJob(Map conf=[:]){
476476
def retimage
477477
(retimage, image) = getDockerImage(conf)
478478

479-
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
479+
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') {
480480
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
481481
timeout(time: 20, unit: 'HOURS')
482482
{
@@ -538,7 +538,7 @@ def Build_CK(Map conf=[:]){
538538
def image
539539
def retimage
540540

541-
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
541+
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') {
542542
try {
543543
(retimage, image) = getDockerImage(conf)
544544
withDockerContainer(image: image, args: dockerOpts) {
@@ -728,7 +728,7 @@ def process_results(Map conf=[:]){
728728
def variant = env.STAGE_NAME
729729
def retimage
730730

731-
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
731+
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') {
732732
try
733733
{
734734
echo "Pulling image: ${image}"
@@ -836,7 +836,7 @@ def run_aiter_tests(Map conf=[:]){
836836
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
837837
echo "Docker flags: ${dockerOpts}"
838838

839-
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
839+
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') {
840840
try
841841
{
842842
echo "Pulling image: ${image}"
@@ -859,6 +859,7 @@ def run_aiter_tests(Map conf=[:]){
859859
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
860860
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
861861
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py"
862+
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py"
862863
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py"
863864
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py"
864865
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py"
@@ -894,7 +895,7 @@ def run_pytorch_tests(Map conf=[:]){
894895
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
895896
echo "Docker flags: ${dockerOpts}"
896897

897-
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
898+
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') {
898899
try
899900
{
900901
echo "Pulling image: ${image}"
@@ -930,7 +931,8 @@ def run_pytorch_tests(Map conf=[:]){
930931
}
931932

932933
//launch develop branch daily jobs
933-
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
934+
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true
935+
0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
934936
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
935937
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
936938
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
@@ -957,8 +959,8 @@ pipeline {
957959
description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
958960
string(
959961
name: 'ROCMVERSION',
960-
defaultValue: '6.4.1',
961-
description: 'Specify which ROCM version to use: 6.4.1 (default).')
962+
defaultValue: '7.0.1',
963+
description: 'Specify which ROCM version to use: 7.0.1 (default).')
962964
string(
963965
name: 'COMPILER_VERSION',
964966
defaultValue: '',
@@ -1037,8 +1039,8 @@ pipeline {
10371039
description: "Build CK and run tests on gfx942 (default: ON)")
10381040
booleanParam(
10391041
name: "BUILD_GFX950",
1040-
defaultValue: false,
1041-
description: "Build CK and run tests on gfx950 (default: OFF)")
1042+
defaultValue: true,
1043+
description: "Build CK and run tests on gfx950 (default: ON)")
10421044
booleanParam(
10431045
name: "BUILD_GFX10",
10441046
defaultValue: true,
@@ -1290,7 +1292,7 @@ pipeline {
12901292
agent{ label rocmnode("gfx90a")}
12911293
environment{
12921294
setup_args = "NO_CK_BUILD"
1293-
execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \
1295+
execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_PREFIX_PATH=/opt/rocm ../codegen && \
12941296
make -j64 check"""
12951297
}
12961298
steps{
@@ -1350,15 +1352,14 @@ pipeline {
13501352
}
13511353
agent{ label rocmnode("gfx950") }
13521354
environment{
1353-
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0"
13541355
setup_args = "NO_CK_BUILD"
13551356
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx950 && \
13561357
make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \
13571358
cd ../ &&
13581359
example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx950 """
13591360
}
13601361
steps{
1361-
buildHipClangJobAndReboot(setup_args:setup_args, docker_name: docker_name, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
1362+
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
13621363
cleanWs()
13631364
}
13641365
}
@@ -1566,7 +1567,7 @@ pipeline {
15661567
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
15671568
}
15681569
steps{
1569-
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
1570+
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
15701571
cleanWs()
15711572
}
15721573
}
@@ -1631,7 +1632,7 @@ pipeline {
16311632
-D CMAKE_BUILD_TYPE=Release \
16321633
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
16331634

1634-
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0")
1635+
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB}:ck_ub24.04_rocm7.0.1")
16351636
}
16361637
cleanWs()
16371638
}
@@ -1657,13 +1658,13 @@ pipeline {
16571658
cleanWs()
16581659
}
16591660
}
1660-
stage("Build CK and run Tests on gfx1101")
1661+
stage("Build CK and run Tests on gfx11")
16611662
{
16621663
when {
16631664
beforeAgent true
16641665
expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
16651666
}
1666-
agent{ label rocmnode("gfx1101") }
1667+
agent{ label 'miopen && (gfx1101 || gfx1100)' }
16671668
environment{
16681669
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """
16691670
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \

codegen/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h)
1212
find_package(ROCM)
1313
include(ROCMInstallTargets)
1414
include(ROCMTest)
15+
find_package(hiprtc REQUIRED)
1516

1617
rocm_setup_version(VERSION 1.0)
1718

@@ -27,7 +28,7 @@ add_compile_options(-std=c++20)
2728
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
2829
# TODO: Use object library
2930
add_library(ck_host STATIC ${SOURCES})
30-
target_link_libraries(ck_host PRIVATE ck_headers)
31+
target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc)
3132

3233
set_target_properties(ck_host PROPERTIES
3334
LINKER_LANGUAGE CXX

example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using BDataType = ck::half_t;
3636
using CDataType = ck::half_t;
3737
using AccDataType = float;
3838
#else
39-
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
39+
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>;
4040
using ADataType = float;
4141
using BDataType = float;
4242
using CDataType = float;
@@ -185,7 +185,6 @@ int main(int argc, char* argv[])
185185
auto a_element_op = AElementOp{};
186186
auto b_element_op = BElementOp{};
187187
auto c_element_op = CElementOp{};
188-
189188
// do GEMM
190189
auto gemm = DeviceGemmInstance{};
191190
auto invoker = gemm.MakeInvoker();
@@ -209,8 +208,7 @@ int main(int argc, char* argv[])
209208
return 0;
210209
}
211210

212-
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
213-
211+
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
214212
std::size_t flop = std::size_t(2) * M * N * K;
215213
std::size_t num_btype =
216214
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

example/01_gemm/run_gemm_example.inc

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#pragma once
5-
#include "ck/library/utility/validation_common.hpp"
65

76
// use macro to minimize code change
87
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
@@ -29,11 +28,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
2928
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
3029
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
3130
{
32-
return HostTensorDescriptor({row, col}, {stride, 1_uz});
31+
return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout);
3332
}
3433
else
3534
{
36-
return HostTensorDescriptor({row, col}, {1_uz, stride});
35+
return HostTensorDescriptor({row, col}, {1_uz, stride}, layout);
3736
}
3837
};
3938

@@ -59,17 +58,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
5958
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
6059
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
6160

62-
try
63-
{
64-
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
65-
M, N, K, StrideA, StrideB, StrideC);
66-
}
67-
catch(const std::runtime_error& e)
68-
{
69-
std::cerr << "Error: " << e.what() << std::endl;
70-
return false;
71-
}
72-
7361
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
7462
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
7563

example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ int main(int argc, char* argv[])
174174
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
175175
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
176176

177+
const auto StrideD = std::is_same<decltype(ELayout{}), ck::tensor_layout::gemm::RowMajor>::value
178+
? d_m_n.mDesc.GetStrides()[0]
179+
: d_m_n.mDesc.GetStrides()[1];
177180
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
178181
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
179182
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
@@ -221,7 +224,7 @@ int main(int argc, char* argv[])
221224
K,
222225
StrideA,
223226
StrideB,
224-
std::array<ck::index_t, 1>{0},
227+
std::array<ck::index_t, 1>{static_cast<int>(StrideD)},
225228
StrideE,
226229
a_element_op,
227230
b_element_op,

example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
77
#endif
88
using namespace ck::literals;
99

10-
auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size;
10+
ProblemSize ps =
11+
problem_size; // make mutable copy because default stride values of 0 need to be updated
12+
auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = ps;
1113

1214
auto f_host_tensor_descriptor =
1315
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -41,6 +43,30 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
4143
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
4244
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
4345

46+
// If any user-provided leading stride <= 0, replace it with the one determined by the
47+
// created tensor descriptor. For RowMajor the leading stride is index 0, for ColMajor index 1.
48+
auto fetch_leading_stride = [](const auto& tensor, auto layout_tag) -> int {
49+
if constexpr(std::is_same_v<decltype(layout_tag), ck::tensor_layout::gemm::RowMajor>)
50+
{
51+
return static_cast<int>(tensor.GetStrides()[0]);
52+
}
53+
else
54+
{
55+
return static_cast<int>(tensor.GetStrides()[1]);
56+
}
57+
};
58+
59+
if(StrideA <= 0)
60+
StrideA = fetch_leading_stride(a_m_k, ALayout{});
61+
if(StrideB <= 0)
62+
StrideB = fetch_leading_stride(b_k_n, BLayout{});
63+
if(StrideD0 <= 0)
64+
StrideD0 = fetch_leading_stride(d0_m_n, D0Layout{});
65+
if(StrideD1 <= 0)
66+
StrideD1 = fetch_leading_stride(d1_m_n, D1Layout{});
67+
if(StrideE <= 0)
68+
StrideE = fetch_leading_stride(e_m_n_host_result, ELayout{});
69+
4470
switch(config.init_method)
4571
{
4672
case 0: break;

0 commit comments

Comments
 (0)