Hstu attention n0loop fused unroll pr #2896
base: develop
Conversation
…uilding succeeded)
…oLocal to save vgprs for non-local situations
…st-K/last-V or last-K/first-V
…n when both causal=true and local=true
Readme needs work for clarity.
@@ -0,0 +1,64 @@
# HSTU attention operator

Suggested change:
HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,
The HSTU-attention operator is an operator which takes as input three tensors `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,

Are the parameters for defining functional masking, or is the operator for functional masking? It's not clear from the sentence.
But if that's what it is, then this can be changed to:
Suggested change:
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:
`v: [batches, seqlen, nhead, hdim_v]`, as well as parameters that define functional masking to do the following:

Suggested change:
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get the intermediate tensor `s: [batches, nhead, seqlen, seqlen]`

Suggested change:
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
* Update `s` by filtering it with a functional mask that includes a lower-triangular mask, a diagonal window causal mask, and

Suggested change:
as well assequence mask
a sequence mask.

* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]`
Suggested change:
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen_q, nhead, headsz_v]`
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get the final tensor `o: [batches, seqlen_q, nhead, headsz_v]`

Suggested change:
* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`
Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`
This isn't a thing that the operator does, so it shouldn't be in the same bullet list
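
To make the quoted steps concrete, below is a minimal, single-threaded reference sketch of the computation (not the fused kernel). It assumes contiguous row-major layouts and shows only the plain lower-triangular causal mask; the diagonal-window mask, the sequence mask, jagged batches, and any scaling the real operator applies are left out, and the function name `hstu_attention_ref` is made up for illustration.

```cpp
// Illustrative reference only: a naive sketch of the steps listed in the
// README excerpt, assuming contiguous row-major layouts. Only the plain
// lower-triangular causal mask is shown; the diagonal-window mask, sequence
// mask, jagged batches, and any scaling used by the real kernel are omitted.
#include <cmath>
#include <cstddef>
#include <vector>

// SiLU(x) = x * sigmoid(x)
static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// q, k: [batches, seqlen, nhead, hdim_qk]; v, o: [batches, seqlen, nhead, hdim_v]
void hstu_attention_ref(const std::vector<float>& q, const std::vector<float>& k,
                        const std::vector<float>& v, std::vector<float>& o,
                        int batches, int seqlen, int nhead, int hdim_qk, int hdim_v)
{
    auto idx = [&](int b, int s, int h, int d, int hdim) {
        return ((static_cast<std::size_t>(b) * seqlen + s) * nhead + h) * hdim + d;
    };
    o.assign(static_cast<std::size_t>(batches) * seqlen * nhead * hdim_v, 0.0f);

    for (int b = 0; b < batches; ++b)
        for (int h = 0; h < nhead; ++h)
            for (int i = 0; i < seqlen; ++i)   // row of s (query position)
                for (int j = 0; j <= i; ++j)   // lower-triangular causal mask
                {
                    // s[b, h, i, j] = sum_d q[b, i, h, d] * k[b, j, h, d]
                    float s = 0.0f;
                    for (int d = 0; d < hdim_qk; ++d)
                        s += q[idx(b, i, h, d, hdim_qk)] * k[idx(b, j, h, d, hdim_qk)];
                    // p[b, h, i, j] = SiLU(s); fold the p * v product directly
                    // into o instead of materializing p.
                    const float p = silu(s);
                    for (int d = 0; d < hdim_v; ++d)
                        o[idx(b, i, h, d, hdim_v)] += p * v[idx(b, j, h, d, hdim_v)];
                }
}
```

For jagged inputs, the same loops would run per batch over the range `sequence_offsets[b]` to `sequence_offsets[b+1]` instead of a fixed `seqlen`, so for example offsets `{0, 5, 12}` would give batch 0 a seqlen of 5 and batch 1 a seqlen of 7.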

## implementation

Suggested change:
The operator is implemented using a fused kernel in the example:
The operator is implemented using a fused kernel:

* Tensor S and Tensor P only exist in VGPRs as per-workgroup tiles, no global memory access is needed
This doesn't need to be a bullet point. You can combine it with the sentence above.

#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh

Suggested change:
Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments. Which is like the following:
Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments.
To be honest, I'd rather see the explanations here, or at least have the code snippet commented. It doesn't need to be everything but some of it.
This PR brings an implementation of HSTU attention on ck_tile. HSTU attention is very different from the fmha already implemented in ck_tile; for details, please refer to the HSTU paper. The implementation is well verified on MI300 for both functionality and targeted performance, but it does not include any optimization for MI350.
To build:
#> cd build; ../scripts/cmake-ck-dev.sh .. gfx942; make -j 128 tile_example_hstu_attention
To verify:
#> . examples/ck_tile/23_hstu_attention/scripts/test_hstu_attention.sh
The HSTU code is all located under the folder examples/ck_tile/23_hstu_attention, but this PR also makes some tiny changes to the core ck_tile code under include/ck_tile/core/tensor.