Skip to content
Open
Show file tree
Hide file tree
Changes from 169 commits
Commits
Show all changes
175 commits
Select commit Hold shift + click to select a range
4a0fc29
Initial reference implementation of hstu attention
qianfengz Mar 28, 2025
83f2924
fix the jagged mode tensor access in reference_hstu_attention
qianfengz Mar 29, 2025
121a950
Add hstu attention kernel implementation, instances and interfaces (b…
qianfengz Apr 3, 2025
7337345
Fix and change in example
qianfengz Apr 3, 2025
10e72d3
Change in HstBlockMasking and kernel/reference codes for using masking
qianfengz Apr 3, 2025
dbcf38a
Fixes and updates
qianfengz Apr 7, 2025
dc2f72a
Fix in hstu-attention pipeline (which makes some testing cases passed)
qianfengz Apr 8, 2025
561d490
Fix in kernel and forward dispatch for jagged mode
qianfengz Apr 8, 2025
9cb2dca
Add several verification test cases
qianfengz Apr 8, 2025
86c0e45
Add benchmark_hstu_attention.sh
qianfengz Apr 9, 2025
dd2cd2c
Tune the input initialization to avoid over-flow in silu
qianfengz Apr 9, 2025
1766e6d
Update to the scripts and error thresholds
qianfengz Apr 9, 2025
71697d9
Add output of estimated TFLOPS
qianfengz Apr 9, 2025
53e5679
Fix in calculation of total_flops and update benchmark scripts
qianfengz Apr 13, 2025
238e78d
Update the in pipeline codes
qianfengz Apr 13, 2025
c2e6ab8
Add IsFirstVLdsBufferOverlapLastKLdsBuffer() check to reduce call of …
qianfengz Apr 13, 2025
fff13b6
Update to partially reduce the register spilling
qianfengz Apr 15, 2025
cad1356
Use packed cast_tile for fp16
qianfengz Apr 15, 2025
3cd1b13
Split HstuBlockMasking into HstuBlockMaskWithLocal and HstuBlockMaskN…
qianfengz Apr 15, 2025
d1749b3
Use kM0=128 kN0=64 to completely remove the vgprs spilling
qianfengz Apr 15, 2025
226a254
Remove the comparing of row/col to max_uih_len in masking
qianfengz Apr 16, 2025
1351d9c
Use exp2() to calculate exp() for better performance
qianfengz Apr 16, 2025
6086ead
Add scripts for comparing with triton
qianfengz Apr 17, 2025
b0ae270
Fix the integer overflow in total_flops calculation
qianfengz Apr 17, 2025
ca1ae84
Remove one line of __builtin_amdgcn_sched_barrier(0)
qianfengz Apr 17, 2025
f12a472
Tiny codes simplification in pipeline
qianfengz Apr 18, 2025
88e54a8
Use shared ring Lds buffers for K/V to avoid over-lapping between fir…
qianfengz Apr 18, 2025
efc786f
Remove un-needed __builtin_amdgcn_sched_barrier(0)
qianfengz Apr 18, 2025
ee259a8
Fix the GetTileRangeAlongX() to align with the hstu masking definitio…
qianfengz Apr 18, 2025
2546e90
Change gemm0 to iterate along kN0 so that BlockGemm can overlap with …
qianfengz Apr 19, 2025
677fd60
Add script compare_with_triton_2.sh for measuring the jagged cases of…
qianfengz Apr 22, 2025
58ab553
Fix in GetTileRangeAlongX
qianfengz Apr 22, 2025
65ddb1a
Fix the script name
qianfengz Apr 22, 2025
26db7e0
Use kN0=64 to save vgprs
qianfengz Apr 22, 2025
7316a44
Update exp() in ck_tile/core/numeric/math.hpp to use __expf
qianfengz Apr 21, 2025
022ed3f
Back to use exp() instead of exp2() since exp() in ck_tile using fast…
qianfengz Apr 21, 2025
8dcde8d
Fix in generate_instances.py and re-generated the instances
qianfengz Apr 23, 2025
2d2e194
Update in using masking for the case where kMasking is false and kPad…
qianfengz Apr 23, 2025
ce46652
Move silu calculation to gemm1 iteration and try to interleave gemm_1…
qianfengz Apr 23, 2025
aec1917
Combine minus with scale_s
qianfengz Apr 24, 2025
7848d15
Using __builtin_amdgcn_rcpf in siLU function
qianfengz Apr 24, 2025
cea919a
Use 16x16x16 WarpGemm
qianfengz Apr 24, 2025
a41371f
Update in K-Lds laying-out to consider for both WarpGemm-32x32x16 and…
qianfengz Apr 24, 2025
05910eb
Add support for WarpGem-16x16x32 in QK-BlockGemm (which enables using…
qianfengz Apr 25, 2025
7818cce
Rename the performance measurement scripts
qianfengz Apr 25, 2025
4a49119
Update the seqlen_k_curr inside the first gemm loop
qianfengz Apr 25, 2025
80677eb
Code re-arrangement in pipeline
qianfengz Apr 25, 2025
4ae9acd
Revert "Update exp() in ck_tile/core/numeric/math.hpp to use __expf"
qianfengz Apr 25, 2025
27f7ab4
Use compiler builtin directly in f_silu for float type
qianfengz Apr 25, 2025
9996270
Tiny update in IsTokenPairInsideMask()
qianfengz Apr 25, 2025
1b463e9
Add scripts for measuring jagged with/no causal cases
qianfengz Apr 25, 2025
95c93ba
Update the GridSize() and GetTileIndex() in hstu kernel
qianfengz Apr 26, 2025
054c397
Replace set_tile_if() by sweep_tile_span() to reduce branching
qianfengz Apr 27, 2025
1af2702
Add IsFullTileInsideMask() to avoid pixel-by-pixel checking when kUse…
qianfengz Apr 27, 2025
f53be61
Put two gemms call inside one n0loop unroll
qianfengz Apr 28, 2025
f1f4e24
Adjust the v_tile and k_tile loading location
qianfengz Apr 28, 2025
d63dab9
Hack block_gemm_areg_bsmem_creg_v2 to let s_acc for gemm_0 not need b…
qianfengz Apr 28, 2025
2972de4
Temporarily close the instance for hdim64 and hdim256 to save compili…
qianfengz Apr 30, 2025
da89540
Use kN0=32
qianfengz Apr 30, 2025
611f2ce
Override and fix GetAlignmentK()
qianfengz May 3, 2025
374e062
Remove using cast_tile_pk_fp16_fp32 for better accuracy for fp16 hstu…
qianfengz May 6, 2025
72d55d1
Add max_seqlen as divider in siLu
qianfengz May 6, 2025
079f7e3
Use type_convert rather than static_cast in f_silu
qianfengz May 7, 2025
632fd06
Use kK1=16
qianfengz May 7, 2025
d32851e
Simplification in the static iterations of block_gemm_areg_bsmem_creg…
qianfengz May 7, 2025
1d1dd8f
Revert "Temporarily close the instance for hdim64 and hdim256 to save…
qianfengz May 7, 2025
79cd1f0
Fix sequence dim length for o_dram descriptor in the kernel
qianfengz May 10, 2025
3a320bc
Add test cases for better functional verification
qianfengz May 10, 2025
c3761c3
Update the rules of hstu masking
qianfengz May 13, 2025
5b0a261
Add -save_mask option to the example to output int8 mask tensor
qianfengz May 14, 2025
b0d3704
Add scripts (test_ck_hstu_mask.sh and test_pytorch_hstu_mask.py) for …
qianfengz May 14, 2025
5869687
Set example option -save_mask default to 0
qianfengz May 14, 2025
0771390
Move the dividing by max_seqlen out of f_silu to be handle outside th…
qianfengz May 18, 2025
58e45ec
Move the lambda for dividing by max_seqlen from kernel to pipeline
qianfengz May 18, 2025
473fbc3
Rename the hacked block_gemm_areg_bsmem_creg_v2
qianfengz May 15, 2025
7c0ac51
Hack block_gemm_areg_bsmem_creg_v2 for gemm_1
qianfengz May 15, 2025
afd7793
Prefetch K for next iteration from LDS in block_gemm_areg_bsmem_creg …
qianfengz May 18, 2025
694295a
Move b_warp_windows construction into k-iteration in block_gemm_areg_…
qianfengz May 18, 2025
ff3415d
Prefetch b_warp_tensor for next nIter and move b_warp_windows constru…
qianfengz May 18, 2025
e4e70f8
Set the block_per_cu to 3 for hdim-128
qianfengz May 18, 2025
4e65469
Add _builtin_amdgcn_sched_barrier(0) for instructing the compiler for…
qianfengz May 18, 2025
f582c21
Replace s_acc and pcomp tile array by single tile object for simplifi…
qianfengz May 19, 2025
902b1c6
Move k_tile loading in the loop earlier
qianfengz May 19, 2025
f411d67
Move k_tile loading and v_tile loading earlier in the loop
qianfengz May 19, 2025
14ab6f1
Adjust the codes before the main-loop
qianfengz May 19, 2025
fac03ab
Change do-while main-loop to while-do and remove early exiting check
qianfengz May 19, 2025
29cf161
Enable RTN fp32 to bf16 conversion by adding compiler option in CMake…
qianfengz May 20, 2025
0a8ea6b
Adjust the threshold values for fp16/bf16 in the example
qianfengz May 20, 2025
a1346aa
Update the reference hstu to not do fp32 to fp16/bf16 conversion befo…
qianfengz May 20, 2025
81f7b13
Use LDS to in-directly load Q-tile to enable dwordx4 loading and avoi…
qianfengz May 21, 2025
dc0977f
Use NRepetitions2DEpilogue for outputing o_acc tile
qianfengz May 26, 2025
c9e1935
Update to the method for calculating max_seqlen in the example
qianfengz May 27, 2025
10c3512
Add example parameter max_seqlen and max_target
qianfengz May 27, 2025
68a5ab8
Add init_qkv and dump_output example parameters for easier debugging
qianfengz May 28, 2025
36a0f20
not-critical updates in example and block_masking codes
qianfengz May 29, 2025
781cba3
Convert P to fp16/bf16 before doing second gemm in reference hstu imp…
qianfengz May 29, 2025
832747c
Add example parameter alpha to ease the testing
qianfengz May 30, 2025
bec35ab
Tune the settings for hdim-256
qianfengz May 30, 2025
9582ae2
Move dividing by max_seqlen to end of Gemm1 loop in the reference hst…
qianfengz May 30, 2025
2bb59df
Add two scripts
qianfengz Jun 6, 2025
9e6a240
Move all test and bench scripts to folder scripts
qianfengz Jun 6, 2025
d7930cd
Update IsFulleTileInsideMask() for kUseLocal is true situtation
qianfengz Jun 6, 2025
b2db644
Add assert(contextual_seqlen >= 0) in example
qianfengz Jun 6, 2025
84eb9ad
Move GetKPackV() and GetAlignmentV() out of ck_tile fmha to hstu pipe…
qianfengz Jun 7, 2025
4632d30
Improve the VDramTileDistribution and VLds layout for better device l…
qianfengz Jun 8, 2025
08886e9
Enable BATCH_AS_FIRST_GRID_DIM grid-scheduling and use ASSUME_LEAST_V…
qianfengz Jun 10, 2025
9e62359
Tiny fix in hstu attention IsFullTileInsideMask()
Jun 18, 2025
09ac146
Align the -seqlens=xxx in the mattn0_full0 and mattn256_full256 scrip…
qianfengz Jun 18, 2025
f9caae2
Use batch dim as first grid dim by default and replace env ASSUME_LEA…
qianfengz Jun 18, 2025
a5f24d7
Change while() do to do while() for the main loop to let the compiler…
qianfengz Jun 21, 2025
4fa6474
Fix in using KV LdsBuffers to avoid un-expected over-writting that ca…
qianfengz Jun 21, 2025
463a198
Completely remove the dependency to include/ck_tile/ops/fmha/ops headers
qianfengz Jun 22, 2025
c87a217
Update to test_ck_hstu_mask.sh and test_pytorch_hstu_mask.py to align…
qianfengz Jun 22, 2025
63a47d7
Fix masking for min_full_attn_seqlen > 0 situation
qianfengz Jun 22, 2025
dc7e62a
Simplify the codes in all host/device IsTokenPairInsideMask() trying …
qianfengz Jun 23, 2025
60d8ffb
Use two work-groups per compute-unit for scheduling the kernel
qianfengz Jun 26, 2025
3c300d3
Tiny movement in the code lines of the pipeline
qianfengz Jun 26, 2025
5451912
Let causal == 0 cases to do IsFullTileInsideMask() checking before ca…
qianfengz Jun 26, 2025
8d30e46
Remove using i_loop and num_loops since seqlen_k_curr and seqlen_k_en…
qianfengz Jul 6, 2025
6825618
Moving code-lines in hstu pipeline
qianfengz Jul 7, 2025
9171b35
Add including of block_dropout.hpp in the hstu kernel to avoid potent…
qianfengz Jul 11, 2025
0206b34
[Performance] use iglp compiler instruction to tune the codes around …
qianfengz Jul 14, 2025
c016955
Fix the calculation of number of instructions used by sched_group_bar…
qianfengz Jul 15, 2025
fdd9c11
Remove num_target from HstuBlockMask class member since it overlaps t…
qianfengz Jul 15, 2025
0306a1e
Re-org the kernel parameters in HstuAttentionFwdBatchModeBaseKargs an…
qianfengz Jul 17, 2025
f0c8dca
Fix bug in generate_instances.py and re-generate the instances
qianfengz Jul 17, 2025
ed062f9
Disable support of hdim64 amnd hdim256 for quick compiling and testing
qianfengz Jul 17, 2025
fed1474
Revert "Disable support of hdim64 amnd hdim256 for quick compiling an…
qianfengz Jul 17, 2025
acb6cd8
Move store_tile() caled before the current iteration
qianfengz Jul 21, 2025
906ab84
Fix in using sched_group_barrier()
qianfengz Jul 21, 2025
fcd41a6
Re-arrange the codes section for using sched_group_barrier
qianfengz Jul 21, 2025
ecf6a86
Correct some comments
qianfengz Jul 21, 2025
203e22b
Change the seqlen_q dim padding setting for o_dram and bias_dram
qianfengz Jul 22, 2025
ce6a044
Fix comments in test_pytorch_hstu_mask.py scripts
qianfengz Jul 22, 2025
f49fe28
[Performance] Use separate workgroups to handle seqlen scope [max_uih…
qianfengz Jul 23, 2025
cf012c2
Adjust the codes related to calculate i_m0 in the kernel
qianfengz Jul 23, 2025
01c123d
Fix in GetTileRangeAlongX() and IsFullTileInsideMask() of HstuBlockMa…
qianfengz Jul 25, 2025
43a9768
Add three scripts for verification of jagged causal cases
qianfengz Jul 25, 2025
29d3dc9
Update in GetTileRangeAlongX to consider for non-causal+local_size>0 …
qianfengz Jul 25, 2025
3483af0
Fix added case in test_hstu_attention.sh
qianfengz Jul 25, 2025
de71d33
Use __builtin_amdgcn_sched_barrier(0x1) to prevent the compiler from …
qianfengz Aug 1, 2025
7c9032d
Replace the integer max_seqlen by float scale_p as kernel/pipeline pa…
qianfengz Aug 1, 2025
f27d8ce
Add attn_scale MakeKargs() parameter support and update in example, r…
qianfengz Aug 3, 2025
ae05715
[ck_tile] Remove useless code lines in make_wave_buffer_resource
qianfengz Aug 4, 2025
4026122
[ck_tile] Add get_partition_index_v2 which uses warp_id in vgpr and t…
qianfengz Aug 6, 2025
fd25f5d
[ck_tile] Merge get_partition_index() and get_partition_index_v2() to…
qianfengz Aug 8, 2025
971d0d9
Update to support min_full_attn_seqlen be bigger than max_uih_len
qianfengz Aug 8, 2025
1404336
Update HstuBlockMaskWithLocal::GetTileRangeAlongX, add comments and t…
qianfengz Aug 10, 2025
832ef5d
Tiny fix and comments in HstuBlockMaskWithLocal::IsFullTimeInsideMask()
qianfengz Aug 10, 2025
d5d4a0d
Add simple handling for max_atten_seqlen bigger than max_uih_len situ…
qianfengz Aug 10, 2025
30dd274
Adjust the atol and rtol and fix the check_err() using in example_hst…
qianfengz Aug 12, 2025
d2b0f75
Tiny fix in HstuBlockMaskWithLocal::GetTileRangeAlongX()
qianfengz Aug 12, 2025
fb09061
Add norm_dist parameter for hstu example to select either normal or u…
qianfengz Aug 12, 2025
89cd5ff
Merge branch 'develop' into hstu_attention_n0loop_fused_unroll
qianfengz Aug 18, 2025
7b68b6e
Tiny change in pipeline BlockGemm definition to adapt to the latest m…
qianfengz Aug 18, 2025
21ca848
Remove selectable VLayout for simplifying the codes since hdim is alw…
qianfengz Aug 20, 2025
328fc71
Use xor transform to implement Q/K Lds descriptor for kKpack == 8 cases
qianfengz Aug 21, 2025
4281f54
Merge branch 'develop' into hstu_attention_n0loop_fused_unroll and ti…
qianfengz Sep 1, 2025
2357499
[ck_tile] Fix in set_slice_tile()
qianfengz Sep 8, 2025
1aa6cf8
Use set_slice_tilie() to replace direct thread_buffer assignment
qianfengz Sep 9, 2025
4828e4b
Clarify the using of kSubQKHeaddim and kQKHeaddim so that less regula…
qianfengz Sep 9, 2025
9ae76b2
Remove using MakeKargsImpl() to simplify the hstu kernel
qianfengz Sep 10, 2025
c5994e1
Unify the license statements on all the source files
qianfengz Sep 11, 2025
f8398f5
Detach HstuBlockMask from pipeline definition and construct the HstuB…
qianfengz Sep 12, 2025
75a7332
Smalle update in reference hstu attention
qianfengz Sep 13, 2025
8a01016
Add HSTU_CHECK() and use it in example codes
qianfengz Sep 13, 2025
072459c
Remove un-necessary HSTU_CHECK() callings
qianfengz Sep 13, 2025
fec6e8d
Remove useless constant statement in the kernel
qianfengz Sep 19, 2025
090459d
Merge branch 'develop' into hstu_attention_n0loop_fused_unroll_pr
qianfengz Sep 22, 2025
a57413e
Merge branch 'develop' into hstu_attention_n0loop_fused_unroll_pr
asleepzzz Sep 23, 2025
f158b32
Move hstu from fhold 18_hstu_attention to 23_hstu_attention
qianfengz Sep 23, 2025
fb22e5e
Update to hstu READM.md
qianfengz Sep 23, 2025
6ba01e7
Simplify the warp_gemm definitions in GetQKBlockGemm and GetKVBlockGemm
qianfengz Sep 25, 2025
7a67ddc
Fix in GetQKBlockGemm()
qianfengz Sep 27, 2025
cc16906
re-format using clang-format-18
qianfengz Sep 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions example/ck_tile/18_hstu_attention/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Build setup for the HSTU-attention example executable.
set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
# Kernel instance sources are generated into instances/ (see generate_instances.py);
# glob them rather than listing them by hand.
file(GLOB INSTANCE_SRCS instances/*.cpp)
# Dispatch/interface translation units: one per (jagged|batched) x (bf16|fp16) combination.
set(INTERFACES_SRCS hstu_attention_jagged_forward_bf16.cpp hstu_attention_jagged_forward_fp16.cpp hstu_attention_batched_forward_bf16.cpp hstu_attention_batched_forward_fp16.cpp)
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS})

set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS)

# CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3 selects the fp32->bf16 conversion mode
# (round-to-nearest) used by the example; the -Wno-* flags silence warnings
# triggered by the ck_tile headers.
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)

# Opt-in environment switch: when set, disable batch-as-first-grid-dim scheduling
# (HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=0) for workloads with highly varied seqlens.
if (DEFINED ENV{ASSUME_HIGHLY_VARIED_SEQLEN})
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -DHSTU_SCHED_BATCH_AS_FIRST_GRID_DIM=0)
endif()

target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS})

# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

64 changes: 64 additions & 0 deletions example/ck_tile/18_hstu_attention/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# HSTU attention operator

The HSTU-attention operator takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]`,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,
The HSTU-attention operator is an operator which takes as input three tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,

`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the parameters for definiing functional masking, or is the operator for functional masking? It's not clear from the sentence.

But if that's what it is, then this can be changed to:

Suggested change
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:
`v: [batches, seqlen, nhead, hdim_v]`, as well as parameters that define functional masking to do the following:
``


* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get the intermediate tensor `s: [batches, nhead, seqlen, seqlen]`

* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
* Update `s` by filtering it with a functional mask that includes a lower-triangular mask, a diagonal window causal mask, and

as well as a sequence mask.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
as well assequence mask
a sequence mask.

* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]`
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get the intermediate tensor `p: [batches, nhead, seqlen, seqlen]`

* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen_q, nhead, headsz_v]`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen_q, nhead, headsz_v]`
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get the final tensor `o: [batches, seqlen_q, nhead, headsz_v]`

* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`
Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`

This isn't a thing that the operator does, so it shouldn't be in the same bullet list



## implementation

The operator is implemented using a fused kernel in the example:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The operator is implemented using a fused kernel in the example:
The operator is implemented using a fused kernel:


* Tensor S and Tensor P only exist in VGPRs as per-workgroup tiles, no global memory access is needed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to be a bullet point. You can combine it with the sentence above.


## build

``` bash
#> mkdir build
#> cd build
#> ../script/cmake-ck-dev.sh .. gfx942   # use `rocminfo | grep "gfx"` to check your gpu arch
#> make -j tile_example_hstu_attention
```

## test/verify

``` bash
#> build/bin/tile_example_hstu_attention -v=1 -prec=bf16 -b=10 -jagged=1 -nhead=4 -hdim_qk=128 -hdim_v=128 -seqlen=750,730,733,860,870,788,760,821,833,779 -targets=5,5,6,6,5,6,5,6,4,6
-causal=1 -local_len=5 -context_len=6 -minfull_len=6
#> . example/ck_tile/18_hstu_attention/test_hstu_attention.sh
```

Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments. Which is like the following:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments. Which is like the following:
Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments.

To be honest, I'd rather see the explanations here, or at least have the code snippet commented. It doesn't need to be everything but some of it.


``` C++
arg_parser.insert("v", "1", "whether to do CPU validation or not")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
.insert("b", "12", "batch size")
.insert("nhead", "4", "number of heads")
.insert("hdim_qk", "64", "headdim size of Q/K")
.insert("hdim_v", "64", "headdim size of V/O")
.insert("seqlens", "400", "seqlen of single or all batches for query and key/value tensor, actually allocated seqlen will include the target of each batch and context_len")
.insert("max_seqlen", "0", "max uih_seqlen, can be ignored, or else must be equal or bigger than the maximum of all uih seqlens")
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
.insert("max_target", "0", "max target, can be ignored, or else must be equal or bigger than the maximum of all targets")
.insert("causal", "1", "enable causal mask or not")
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
.insert("context_len", "6", "sequence length at the beginning of the query sequence that should be included for attention")
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
.insert("init_qkv", "0", "initialize q, k, v tensor from local files q.dat, k.dat and v.dat")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
.insert("norm_dist", "0", "if true, initialize the data in normal distribution, or else in uniform distribution")
.insert("alpha", "0", "scale factor of S=Q@K. 0 means equal to 1/sqrt(hdim)")
.insert("attn_scale", "0", "scale factor of SiLu(Q@K), 0 means using 1/max_seqlen for scaling")
.insert("save_mask", "1", "save the mask tensor to disk by the CPU validation codes")
.insert("perf", "0", "whether to measure execution time or not")
.insert("dump_output", "0", "dump both device and reference hstu attention outputs to files, only used when validation is true");
```

Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp"

namespace ck_tile {

// Block-level GEMM where:
//   A is a block distributed tensor (held in registers),
//   B is a block window on shared memory,
//   C is a block distributed tensor (held in registers).
// "Hack_0" variant of BlockGemmARegBSmemCRegV2 (see the included default policy
// header); problem/policy types supply the data types, block shape and warp-gemm
// configuration used by operator() below.
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
struct BlockGemmARegBSmemCRegV2Hack_0
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    // Number of threads cooperating on one block tile.
    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // Block GEMM over C, A, B tiles.
    //
    // NOTE(review): unlike the "C += A * B" contract of the base V2 gemm, the first
    // k-iteration here calls WG{}(a, b), which returns a FRESH accumulator and then
    // overwrites the corresponding slice of c_block_tensor — so for each (m, n)
    // sub-tile the result is effectively C = A * B accumulated over k, and callers
    // do not need to pre-clear c_block_tensor. Subsequent k-iterations accumulate
    // into C as usual.
    //
    // Per n-iteration, B warp tiles for k-slices 0 and 1 are prefetched from LDS up
    // front and the (k+1)-th slice is prefetched inside the k-loop, interleaved with
    // the warp GEMMs; __builtin_amdgcn_sched_barrier(0) fences keep the compiler
    // from reordering loads across the prefetch/compute boundaries.
    //
    // NOTE(review): the unconditional I1 prefetch below indexes b_warp_windows(...)(I1)
    // and b_warp_tensors[I1]; this appears to assume KIterPerWarp >= 2 — TODO confirm
    // against the instantiating pipelines.
    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        // The supplied tensors/windows must carry the problem's declared data types.
        static_assert(
            std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
                std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        // Per-block tile extents, recovered from the (static) argument types.
        constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
        constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
        constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];

        // They must match the compile-time BlockGemmShape of the problem.
        static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
                          KPerBlock == BlockGemmShape::kK,
                      "wrong!");

        // Warp-gemm kind and the MWarp x NWarp layout come from the policy.
        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(config.template at<0>())>;

        constexpr index_t MWarp = config.template at<1>();
        constexpr index_t NWarp = config.template at<2>();

        // How many warp-gemm tiles each warp iterates over along M/N/K.
        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

        // Step sizes (in elements) between consecutive n-/k-iterations.
        constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;

        // This warp's position along the N dimension of the warp grid.
        const index_t iNWarp = get_warp_id() % NWarp;

        // Distribution of C across (MIter, MWarp) x (NIter, NWarp).
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});

        // construct A-block-tensor from A-block-tensor-tmp
        // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
        // distribution
        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
            MakeABlockTileDistribution());

        a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();

        // construct B-warp-window: a WG::kN x WG::kK window into the B block window,
        // offset to this warp's N position.
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
            b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

#if 0 // FIXME: using array will cause register spill
        array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
            {b_warp_window_tmp}};

        for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
        {
            for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
            {
                move_tile_window(b_warp_windows(nIter)(kIter),
                                 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
            }
        }
#else
        // One window per (nIter, kIter); each is (re)initialized from
        // b_warp_window_tmp lazily inside the hot loop below.
        statically_indexed_array<
            statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
            NIterPerWarp>
            b_warp_windows;
#endif

        // check C-block-distribution matches what this block gemm produces
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "wrong!");

        using AWarpDstr = typename WG::AWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        constexpr auto I0 = number<0>{};
        constexpr auto I1 = number<1>{};

        // hot loop: iterate n-tiles; within each, prefetch B k-slices and run warp gemms
        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
            using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0)));

            // Register-resident B tiles for all k-slices of this n-iteration.
            statically_indexed_array<b_warp_tensor_type, KIterPerWarp> b_warp_tensors;

            // Prefetch B for k-slice 0.
            b_warp_windows(nIter)(I0) = b_warp_window_tmp;
            move_tile_window(b_warp_windows(nIter)(I0),
                             {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter});
            b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0));

            // Keep the compiler from reordering across the prefetch boundary.
            __builtin_amdgcn_sched_barrier(0);

            // Prefetch B for k-slice 1 (assumes KIterPerWarp >= 2 — see NOTE above).
            b_warp_windows(nIter)(I1) = b_warp_window_tmp;
            move_tile_window(b_warp_windows(nIter)(I1),
                             {nIter * NPerBlockPerIter, 1 * KPerBlockPerIter});
            b_warp_tensors[I1] = load_tile(b_warp_windows(nIter)(I1));

            __builtin_amdgcn_sched_barrier(0);

            // k-slice 0: produces a fresh accumulator and OVERWRITES the C slice
            // (this is the "hack": C need not be zero-initialized by the caller).
            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                // read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, 0>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                // warp GEMM (returns a new C warp tensor instead of accumulating)
                auto c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[I0]);
                // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

                // write C warp tensor into C block tensor
                c_block_tensor.set_y_sliced_thread_data(
                    merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                    c_warp_tensor.get_thread_buffer());
            });

            // Remaining k-slices: accumulate into C, prefetching slice k+1 first.
            static_for<1, KIterPerWarp, 1>{}([&](auto kIter) {
                // read B warp tensor for the NEXT k-slice from the B block window
                if constexpr(kIter < KIterPerWarp - 1)
                {
                    b_warp_windows(nIter)(number<kIter + 1>{}) = b_warp_window_tmp;
                    move_tile_window(b_warp_windows(nIter)(number<kIter + 1>{}),
                                     {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter});
                    b_warp_tensors[number<kIter + 1>{}] =
                        load_tile(b_warp_windows(nIter)(number<kIter + 1>{}));
                };

                __builtin_amdgcn_sched_barrier(0);

                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    // read A warp tensor from A block tensor
                    AWarpTensor a_warp_tensor;

                    a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM (accumulating form)
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]);
                    // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

using WG = remove_cvref_t<decltype(config.template at<0>())>;

constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();

constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};

constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});

return make_static_tile_distribution(a_block_dstr_encode);
}

CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

using WG = remove_cvref_t<decltype(config.template at<0>())>;

constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();

constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;

constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};

constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}

// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};

} // namespace ck_tile
Loading
Loading