[CK Tile] contraction multi d #2901
base: develop
Conversation
Pull Request Overview
This PR introduces a batched contraction implementation for CK Tile that supports multi-dimensional tensor contractions using the universal GEMM kernel. The implementation supports tensors with multiple G, M, N, K dimensions and applies element-wise operations on multiple D tensors.
- Multi-dimensional batched contraction kernel with configurable tensor dimensions
- Universal GEMM-based implementation with support for split-K batching
- Complete example implementation with CPU reference validation
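For orientation, a hypothetical instantiation of the argument struct could look like the following. This is a sketch only: the leading `NumDimG` parameter, the parameter order, and the namespace placement are assumptions for illustration, not the PR's verbatim API.

```cpp
// Sketch only: the parameters before NumDimN and the namespace
// placement are assumptions, not the PR's verbatim API.
#include <ck_tile/ops/batched_contraction.hpp>

// A contraction with rank-2 G and M groups, rank-1 N and K groups,
// and one D tensor for the element-wise epilogue.
using Args = ck_tile::BatchedContractionKernelArgs</*NumDimG*/ 2,
                                                   /*NumDimM*/ 2,
                                                   /*NumDimN*/ 1,
                                                   /*NumDimK*/ 1,
                                                   /*NumDTensor*/ 1>;
```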
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp | Defines problem template for batched contraction operations |
| include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp | Main kernel implementation with host/device argument structures |
| include/ck_tile/ops/batched_contraction/kernel/batched_conratction_utils.hpp | Utility header (currently empty) |
| include/ck_tile/ops/batched_contraction.hpp | Main include header aggregating all contraction components |
| example/ck_tile/CMakeLists.txt | Adds batched contraction example to build |
| example/ck_tile/40_batched_contraction/* | Complete example implementation with utilities and test cases |
```cpp
ck_tile::index_t NumDimN,
ck_tile::index_t NumDimK,
ck_tile::index_t NumDTensor = 0>
struct BatchedContractionKernelArgs
```
When I look at those kernel args, they seem a bit unclear to me. Comparing them to the old CK definition of this op, https://github.com/ROCm/composable_kernel/blob/a0e48cb317ad8c3dfb9a188c44ba4ef8f1364cb3/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp, the CK Tile one is unclear to me. What are the `_total` values for?
This implementation leverages the universal_gemm kernel. The `_total` values are used as a bridge between tensor contraction and GEMM: each field holds the product of the dimension lengths within its group:

- `M_total` = M0 × M1 × ... × M_{NumDimM-1}
- `N_total` = N0 × N1 × ... × N_{NumDimN-1}
- `K_total` = K0 × K1 × ... × K_{NumDimK-1}
- `G_total` = G0 × G1 × ... × G_{NumDimG-1} (batch size)
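A minimal sketch of how these collapsed extents could be computed on the host (illustrative only; the `collapse` helper and the `index_t` alias here are hypothetical, not the PR's code):

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <numeric>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

// Product of all dimension lengths in one group, e.g.
// M_total = M0 * M1 * ... * M_{NumDimM-1}.
template <std::size_t NumDims>
index_t collapse(const std::array<index_t, NumDims>& lengths)
{
    return std::accumulate(lengths.begin(), lengths.end(), index_t{1},
                           std::multiplies<index_t>{});
}

int main()
{
    // Hypothetical instance with rank-2 G and M groups, rank-1 K group.
    const index_t G_total = collapse<2>({4, 2});   // batch size = 8
    const index_t M_total = collapse<2>({16, 32}); // GEMM M = 512
    const index_t K_total = collapse<1>({64});     // GEMM K = 64
    (void)G_total; (void)M_total; (void)K_total;
    return 0;
}
```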
```cpp
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct BatchedContractionKernel
```
It would be good to have at least some short documentation of this operation, like in old CK:
composable_kernel/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp, lines 15 to 26 at a0e48cb:
```cpp
// Tensor Contraction:
//   input : A
//   input : B
//   input : D0, D1, ...
//   output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
//   A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
//   B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
//   D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
//   E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
```
Sure, I'll add them.
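For readers skimming the thread, a minimal CPU reference matching that doc comment could look like the following. This is a sketch, not the example's actual reference code; it uses the collapsed (G, M, N, K) view, assumes row-major layouts, and assumes `cde_op` is a plain add of a single D tensor.

```cpp
#include <vector>

// E[g, m, n] = cde_op(sum_k A[g, m, k] * B[g, n, k], D0[g, m, n])
// with cde_op assumed to be a simple add (assumption, not the PR's op).
void reference_contraction(const std::vector<float>& A,  // [G, M, K]
                           const std::vector<float>& B,  // [G, N, K]
                           const std::vector<float>& D0, // [G, M, N]
                           std::vector<float>& E,        // [G, M, N]
                           int G, int M, int N, int K)
{
    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += A[(g * M + m) * K + k] * B[(g * N + n) * K + k];
                E[(g * M + m) * N + n] = acc + D0[(g * M + m) * N + n];
            }
}
```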
Proposed changes
Contraction + Multi D kernel in ck_tile using universal GEMM. The kernel supports multiple G, M, N, and K dimensions for the contraction inputs and applies element-wise operations to multiple D tensors. An example has been implemented as well.
Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] Run clang-format on all changed files

Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.