
@qianfengz (Contributor) commented Sep 22, 2025

This PR brings an implementation of HSTU attention to ck_tile. HSTU attention is quite different from the fmha already implemented in ck_tile; for details, please refer to the HSTU paper.

The implementation has been verified on MI300 for both functionality and targeted performance, but no optimizations have been made for MI350.

To build:

```
#> cd build; ../scripts/cmake-ck-dev.sh .. gfx942; make -j 128 tile_example_hstu_attention
```

To verify:

```
#> . examples/ck_tile/23_hstu_attention/scripts/test_hstu_attention.sh
```

The HSTU code is all located under the folder `examples/ck_tile/23_hstu_attention`, but this PR also makes some tiny changes to the core ck_tile code under `include/ck_tile/core/tensor`.
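For reference, the math the example computes can be sketched on the CPU roughly as follows. This is only an illustration of the steps described in the example's README (single batch, single head), not the fused kernel's actual code; the functional mask is reduced to its plain lower-triangular part, and the type and function names here are made up for the sketch.

```cpp
// Illustrative CPU sketch of the HSTU attention steps (not the fused kernel's code):
// s = q * k^T, apply the mask, p = element-wise SiLU(s), o = p * v.
// The real functional mask also has diagonal-window and sequence-mask logic.
#include <cmath>
#include <cstddef>
#include <vector>

using Tensor2D = std::vector<std::vector<float>>; // [rows][cols], one batch, one head

inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// q: [seqlen, hdim_qk], k: [seqlen, hdim_qk], v: [seqlen, hdim_v] -> o: [seqlen, hdim_v]
Tensor2D hstu_attention_ref(const Tensor2D& q, const Tensor2D& k, const Tensor2D& v)
{
    const std::size_t seqlen  = q.size();
    const std::size_t hdim_qk = q[0].size();
    const std::size_t hdim_v  = v[0].size();

    Tensor2D o(seqlen, std::vector<float>(hdim_v, 0.0f));
    for(std::size_t i = 0; i < seqlen; ++i)
    {
        for(std::size_t j = 0; j <= i; ++j) // only the causal part of the functional mask
        {
            float s_ij = 0.0f;
            for(std::size_t d = 0; d < hdim_qk; ++d)
                s_ij += q[i][d] * k[j][d];

            const float p_ij = silu(s_ij); // element-wise SiLU, no softmax as in fmha
            for(std::size_t d = 0; d < hdim_v; ++d)
                o[i][d] += p_ij * v[j][d];
        }
    }
    return o;
}
```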

…oLocal to save vgprs for non-local situations
@spolifroni-amd (Contributor) left a comment


Readme needs work for clarity.

@@ -0,0 +1,64 @@
# HSTU attention operator

HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,

Suggested change
HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,
The HSTU-attention operator is an operator which takes as input three tensors `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,

# HSTU attention operator

HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:

Are the parameters for defining functional masking, or is the operator for functional masking? It's not clear from the sentence.

But if that's what it is, then this can be changed to:

Suggested change
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:
`v: [batches, seqlen, nhead, hdim_v]`, as well as parameters that define functional masking to do the following:

HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk`,
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:

* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`

Suggested change
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get the intermediate tensor `s: [batches, nhead, seqlen, seqlen]`

`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:

* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask

Suggested change
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
* Update `s` by filtering it with a functional mask that includes a lower-triangular mask, a diagonal window causal mask, and


* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_k]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
as well assequence mask

Suggested change
as well assequence mask
a sequence mask.
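One simple way to read the mask being described here, shown as a toy predicate. The interpretation below is an assumption made for illustration, not the example's actual mask logic: an element `(i, j)` of `s` survives only if it is on or below the diagonal, falls within a diagonal window, and lies inside the valid sequence length.

```cpp
// Toy predicate for the kind of functional mask described above (an assumption,
// not the example's actual mask logic).
#include <cstddef>

bool keep_element(std::size_t i, std::size_t j, std::size_t window, std::size_t seqlen)
{
    if(i >= seqlen || j >= seqlen) return false; // sequence mask
    if(j > i)                      return false; // lower-triangular (causal) mask
    return (i - j) <= window;                    // diagonal window around the main diagonal
}
```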

* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
as well assequence mask
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]`
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen_q, nhead, headsz_v]`

Suggested change
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen_q, nhead, headsz_v]`
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get the final tensor `o: [batches, seqlen_q, nhead, headsz_v]`

as well assequence mask
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]`
* Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen_q, nhead, headsz_v]`
* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`

Suggested change
* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`
Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`

This isn't a thing that the operator does, so it shouldn't be in the same bullet list.
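For context on the jagged case mentioned above: assuming the common varlen convention that `sequence_offsets[]` holds `batches + 1` prefix offsets into the packed tensors (an assumption, not verified against the example), each batch's seqlen is simply the difference of adjacent offsets.

```cpp
// Hypothetical helper: derive per-batch seqlens from sequence_offsets[], assuming
// batches + 1 prefix offsets so batch b owns rows [offsets[b], offsets[b+1]).
#include <cstddef>
#include <vector>

std::vector<std::size_t> per_batch_seqlens(const std::vector<std::size_t>& sequence_offsets)
{
    std::vector<std::size_t> seqlens;
    for(std::size_t b = 0; b + 1 < sequence_offsets.size(); ++b)
        seqlens.push_back(sequence_offsets[b + 1] - sequence_offsets[b]);
    return seqlens; // e.g. offsets {0, 3, 7, 12} -> seqlens {3, 4, 5}
}
```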


## implementation

The operator is implemented using a fused kernel in the example:

Suggested change
The operator is implemented using a fused kernel in the example:
The operator is implemented using a fused kernel:


The operator is implemented using a fused kernel in the example:

* Tensor S and Tensor P only exist in VGPRs as per-workgroup tiles, no global memory access is needed

This doesn't need to be a bullet point. You can combine it with the sentence above.

#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
```

Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments. Which is like the following:

Suggested change
Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments. Which is like the following:
Check the example file `example_hstu_attention.cpp` for more information about the command-line arguments.

To be honest, I'd rather see the explanations here, or at least have the code snippet commented. It doesn't need to be everything but some of it.
