Conversation

msaffari-amd
Contributor

Proposed changes

Contraction + Multi-D kernel in ck_tile using universal GEMM. The kernel supports multiple G, M, N, K dimensions, performing the contraction on the inputs and applying an element-wise operation over any number of D tensors. An example has been implemented as well.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@msaffari-amd msaffari-amd force-pushed the LWPCK-3688-cktile-contraction-multi-D branch from 44f78c9 to 5d83464 Compare September 23, 2025 08:25
@bartekxk bartekxk changed the title Lwpck 3688 cktile contraction multi d [CK Tile] contraction multi d Sep 23, 2025
@bartekxk bartekxk requested a review from Copilot September 23, 2025 22:00

@Copilot Copilot AI left a comment


Pull Request Overview

This PR introduces a batched contraction implementation for CK Tile that supports multi-dimensional tensor contractions using the universal GEMM kernel. The implementation supports tensors with multiple G, M, N, K dimensions and applies element-wise operations on multiple D tensors.

  • Multi-dimensional batched contraction kernel with configurable tensor dimensions
  • Universal GEMM-based implementation with support for split-K batching
  • Complete example implementation with CPU reference validation

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp Defines problem template for batched contraction operations
include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp Main kernel implementation with host/device argument structures
include/ck_tile/ops/batched_contraction/kernel/batched_conratction_utils.hpp Utility header (currently empty)
include/ck_tile/ops/batched_contraction.hpp Main include header aggregating all contraction components
example/ck_tile/CMakeLists.txt Adds batched contraction example to build
example/ck_tile/40_batched_contraction/* Complete example implementation with utilities and test cases


bartekxk
bartekxk previously approved these changes Sep 25, 2025
ck_tile::index_t NumDimN,
ck_tile::index_t NumDimK,
ck_tile::index_t NumDTensor = 0>
struct BatchedContractionKernelArgs

When I look at those kernel args, they seem a bit unclear to me compared to the old CK definition of this op: https://github.com/ROCm/composable_kernel/blob/a0e48cb317ad8c3dfb9a188c44ba4ef8f1364cb3/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp. What are the ..._total values for?

Contributor Author


This implementation leverages the universal_gemm kernel.
The _total fields are used as a bridge between tensor contraction and GEMM; each one is the product of the dimensions within its group:

  • M_total = M0 × M1 × ... × M_{NumDimM-1}
  • N_total = N0 × N1 × ... × N_{NumDimN-1}
  • K_total = K0 × K1 × ... × K_{NumDimK-1}
  • G_total = G0 × G1 × ... × G_{NumDimG-1} (batch size)

typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct BatchedContractionKernel

It would be good to have at least some short documentation of this operation, like in old CK:

// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]

Contributor Author


Sure, I'll add them.
