-
Notifications
You must be signed in to change notification settings - Fork 240
Hstu attention n0loop fused unroll pr #2896
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 169 commits
4a0fc29
83f2924
121a950
7337345
10e72d3
dbcf38a
dc2f72a
561d490
9cb2dca
86c0e45
dd2cd2c
1766e6d
71697d9
53e5679
238e78d
c2e6ab8
fff13b6
cad1356
3cd1b13
d1749b3
226a254
1351d9c
6086ead
b0ae270
ca1ae84
f12a472
88e54a8
efc786f
ee259a8
2546e90
677fd60
58ab553
65ddb1a
26db7e0
7316a44
022ed3f
8dcde8d
2d2e194
ce46652
aec1917
7848d15
cea919a
a41371f
05910eb
7818cce
4a49119
80677eb
4ae9acd
27f7ab4
9996270
1b463e9
95c93ba
054c397
1af2702
f53be61
f1f4e24
d63dab9
2972de4
da89540
611f2ce
374e062
72d55d1
079f7e3
632fd06
d32851e
1d1dd8f
79cd1f0
3a320bc
c3761c3
5b0a261
b0d3704
5869687
0771390
58e45ec
473fbc3
7c0ac51
afd7793
694295a
ff3415d
e4e70f8
4e65469
f582c21
902b1c6
f411d67
14ab6f1
fac03ab
29cf161
0a8ea6b
a1346aa
81f7b13
dc0977f
c9e1935
10c3512
68a5ab8
36a0f20
781cba3
832747c
bec35ab
9582ae2
2bb59df
9e6a240
d7930cd
b2db644
84eb9ad
4632d30
08886e9
9e62359
09ac146
f9caae2
a5f24d7
4fa6474
463a198
c87a217
63a47d7
dc7e62a
60d8ffb
3c300d3
5451912
8d30e46
6825618
9171b35
0206b34
c016955
fdd9c11
0306a1e
f0c8dca
ed062f9
fed1474
acb6cd8
906ab84
fcd41a6
ecf6a86
203e22b
ce6a044
f49fe28
cf012c2
01c123d
43a9768
29d3dc9
3483af0
de71d33
7c9032d
f27d8ce
ae05715
4026122
fd25f5d
971d0d9
1404336
832ef5d
d5d4a0d
30dd274
d2b0f75
fb09061
89cd5ff
7b68b6e
21ca848
328fc71
4281f54
2357499
1aa6cf8
4828e4b
9ae76b2
c5994e1
f8398f5
75a7332
8a01016
072459c
fec6e8d
090459d
a57413e
f158b32
fb22e5e
6ba01e7
7a67ddc
cc16906
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Build rules for the HSTU-attention tile example executable.
set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
# Pick up every generated/kernel-instance .cpp under instances/ at configure time.
# NOTE(review): file(GLOB ...) is evaluated at configure time only — newly added
# instance files require a CMake re-run to be picked up.
file(GLOB INSTANCE_SRCS instances/*.cpp)
# Public per-layout/per-datatype entry-point translation units.
set(INTERFACES_SRCS hstu_attention_jagged_forward_bf16.cpp hstu_attention_jagged_forward_fp16.cpp hstu_attention_batched_forward_bf16.cpp hstu_attention_batched_forward_fp16.cpp)
# EXCLUDE_FROM_ALL keeps this target out of the default build (see comment above).
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS})

set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS)

# CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3 selects the float->bf16 conversion mode used
# by ck_tile — presumably a rounding-mode index; confirm against ck_tile/core.
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)

# Optional scheduling tweak: disable "batch as first grid dim" kernel scheduling.
# NOTE(review): the environment variable is only read at CMake *configure* time,
# not at build time — re-run cmake after setting/unsetting it.
if (DEFINED ENV{ASSUME_HIGHLY_VARIED_SEQLEN})
    list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -DHSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=0)
endif()

target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS})

# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,64 @@ | ||||||||
# HSTU attention operator | ||||||||
|
||||||||
The HSTU-attention operator takes tensors `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]`,
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following: | ||||||||
|
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following: | |
`v: [batches, seqlen, nhead, hdim_v]`, as well as parameters that define functional masking to do the following: | |
`` |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]` | |
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get the intermediate tensor `s: [batches, nhead, seqlen, seqlen]` |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask | |
* Update `s` by filtering it with a functional mask that includes a lower-triangular mask, a diagonal window causal mask, and |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as well assequence mask | |
a sequence mask. |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]` | |
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get the intermediate tensor `p: [batches, nhead, seqlen, seqlen]` |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen_q, nhead, headsz_v]` | |
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get the final tensor `o: [batches, seqlen_q, nhead, headsz_v]` |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]` | |
Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]` |
This isn't a thing that the operator does, so it shouldn't be in the same bullet list
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The operator is implemented using a fused kernel in the example: | |
The operator is implemented using a fused kernel: |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't need to be a bullet point. You can combine it with the sentence above.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments. Which is like the following: | |
Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments. |
To be honest, I'd rather see the explanations here, or at least have the code snippet commented. It doesn't need to be everything but some of it.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,281 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#pragma once | ||
|
||
#include "ck_tile/core.hpp" | ||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" | ||
|
||
namespace ck_tile { | ||
|
||
// Block-scope GEMM building block where:
//   A is a block distributed tensor (register-resident),
//   B is a block window on shared memory,
//   C is a block distributed tensor (register-resident).
//
// "Hack_0" variant of BlockGemmARegBSmemCRegV2: the N-iteration loop is the outer
// hot loop, the first two K-slices of B are loaded up front, and subsequent
// K-slices are prefetched one iteration ahead, with explicit scheduler barriers
// between load and compute phases.
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2Hack_0
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // C += A * B
    //
    // c_block_tensor:     accumulator, updated in place (read-modify-write per warp tile).
    // a_block_tensor_tmp: A operand as a block distributed tensor; its thread buffer is
    //                     copied into a tensor with this policy's A distribution.
    // b_block_window_tmp: B operand as a tile window over shared memory.
    //
    // NOTE(review): the hot loop unconditionally initializes and loads the K-slices
    // at I0 and I1 before the per-k prefetch loop, so this variant appears to
    // require KIterPerWarp >= 2 — confirm against the instantiating pipeline.
    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        // The operand data types must match the Problem's declared types exactly.
        static_assert(
            std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
                std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        // Block-tile extents, recovered from the operand types at compile time.
        constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];

        // Operand extents must agree with the configured BlockGemmShape.
        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        // config = (warp-gemm type, MWarp, NWarp) chosen by the policy.
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // How many warp-gemm iterations each warp performs along M/N/K.
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // Step sizes (in elements) for moving the B window between iterations.
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        // This warp's position along the N warp dimension.
        const index_t iNWarp = get_warp_id() % NWarp;

        // Outer (block-level) distribution of C across warps and per-warp iterations.
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        // construct A-block-tensor from A-block-tensor-tmp
        // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
        // distribution
        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
            MakeABlockTileDistribution());

        a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();

        // construct B-warp-window: a WG::kN x WG::kK window into shared memory,
        // offset along N by this warp's iNWarp position.
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
            b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

#if 0 // FIXME: using array will cause register spill
        array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
            {b_warp_window_tmp}};

        for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
        {
            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
            {
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            }
        }
#else
        // Windows are default-constructed here and (re)positioned lazily inside the
        // hot loop below, instead of being precomputed as in the disabled branch.
        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;
#endif

        // check C-block-distribution: the caller's C tensor must carry exactly the
        // distribution this policy produces.
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        // Per-warp Y-dimension lengths, used to slice warp tiles out of block tiles.
        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};

        // hot loop:
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));

            statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;

            // Preload K-slice 0 of B for this N iteration.
            b_warp_windows(nIter)(I0) = b_warp_window_tmp;
            move_tile_window(b_warp_windows(nIter)(I0),
                             {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
            b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));

            // Fence the compiler's instruction scheduler so the load above is not
            // interleaved with the code below.
            __builtin_amdgcn_sched_barrier(0);

            // Preload K-slice 1 (see NOTE above: assumes KIterPerWarp >= 2).
            b_warp_windows(nIter)(I1) = b_warp_window_tmp;
            move_tile_window(b_warp_windows(nIter)(I1),
                             {nIter * NPerBlockPerIter, 1 * KPerBlockPerIter});
            b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1));

            __builtin_amdgcn_sched_barrier(0);

            // K-slice 0: first warp GEMM per M iteration *initializes* the C warp
            // tile (two-operand WG call), rather than accumulating into it.
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, 0>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                // warp GEMM
                auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]);
                // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

                // write C warp tensor into C block tensor
                c_block_tensor.set_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                    c_warp_tensor.get_thread_buffer());
            });

            // Remaining K-slices: prefetch slice kIter+1 while computing with the
            // already-loaded slice kIter, accumulating into C.
            static_for<1, KIterPerWarp, 1>{}([&](auto kIter) {
                // read B warp tensor from B Block window (one slice ahead)
                if constexpr(kIter < KIterPerWarp - 1)
                {
                    b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
                    move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
                                     {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
                    b_warp_tensors[number<kIter + 1>{}] =
                        load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
                };

                __builtin_amdgcn_sched_barrier(0);

                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    // read A warp tensor from A block tensor
                    AWarpTensor a_warp_tensor;

                    a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM (three-operand call: accumulate into c_warp_tensor)
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
                    // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // Builds the tile distribution the A operand must have for this block GEMM:
    // M split across (MIterPerWarp, MWarp), K split across KIterPerWarp, with the
    // NWarp dimension replicated (every N-warp holds the same A data).
    template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
    {
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        constexpr auto a_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});

        return make_static_tile_distribution(a_block_dstr_encode);
    }

    // Creates an (uninitialized) C block tensor with the distribution required by
    // operator() — callers should use this to allocate the accumulator.
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // Must match c_block_outer_dstr_encoding in operator() above.
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }

    // C = A * B
    // Convenience overload: allocates a fresh C block tile and forwards to the
    // accumulating operator() above.
    template <typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        auto c_block_tensor = MakeCBlockTile();
        operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
        return c_block_tensor;
    }
};
|
||
} // namespace ck_tile |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.