
Conversation

AdvancedCompiler
Contributor

PR Category

Operator

Type of Change

Bug Fix

Description

Fix get_scheduler_metadata on the SM90 platform.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

@gemini-code-assist
Contributor

Summary of Changes

Hello @AdvancedCompiler, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug and significantly enhances the get_scheduler_metadata function, particularly for SM90 (Hopper+) platforms. It introduces specialized logic and helper functions for optimal tile sizing, paged KV management with TMA, and GQA packing heuristics. Furthermore, the core Triton JIT kernel responsible for scheduler preparation has been refactored into a two-pass approach, leading to more robust and efficient metadata computation for attention mechanisms.
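
To make the two-pass structure concrete, here is a hedged, pure-PyTorch sketch of the idea. The PR itself implements this as the Triton JIT kernels _prepare_pass1_kernel and _prepare_pass2_kernel; the function name and the split heuristic below (blocks_per_sm, num_sm) are illustrative assumptions, not the kernels' exact logic.

import torch

def prepare_scheduler_sketch(seqlen_q, seqlen_k, block_m, block_n, num_sm=132):
    # Pass 1: per-batch tile counts, plus a global total reduced across the batch.
    num_m_blocks = (seqlen_q + block_m - 1) // block_m
    num_n_blocks = (seqlen_k + block_n - 1) // block_n
    total_blocks = int((num_m_blocks * num_n_blocks).sum())
    # Pass 2: use the global total to choose a dynamic KV-split count per batch
    # element, so split work roughly fills the available SMs.
    blocks_per_sm = max(1, total_blocks // num_sm)
    num_splits_dynamic = torch.clamp(num_n_blocks // blocks_per_sm, min=1)
    return num_m_blocks, num_n_blocks, total_blocks, num_splits_dynamic

Splitting the preparation this way lets the second pass consume a batch-wide reduction (total_blocks) produced by the first pass, which is the usual motivation for a two-kernel prepare step.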

Highlights

  • SM90 (Hopper+) Architecture Optimization: Introduced a new function tile_size_fwd_sm90 and integrated it into get_optimal_block_mn to provide specific tile size heuristics for NVIDIA Hopper+ GPUs, enhancing performance for these architectures.
  • Paged KV with TMA Support: Added get_pagedkv_tma to determine when to leverage Tensor Memory Accelerator (TMA) for Paged Key-Value (KV) caching, specifically for SM90, optimizing memory access patterns.
  • Grouped Query Attention (GQA) Packing Heuristics: Implemented should_pack_gqa and updated get_pack_gqa to intelligently decide whether to pack GQA based on various parameters, improving efficiency for different attention configurations.
  • Refactored Triton JIT Kernel for Scheduler Preparation: The single _prepare_scheduler_kernel has been split and refactored into two distinct Triton JIT kernels, _prepare_pass1_kernel and _prepare_pass2_kernel. This change improves the efficiency and parallelism of calculating block and split metadata for the attention scheduler.
  • Updated Device Capability Access: Switched from flag_gems.runtime.torch_device_fn to direct torch.cuda calls for retrieving device capabilities and properties, streamlining the API usage.
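
The SM90 tile sizing and the torch.cuda device-capability lookup from the highlights above can be sketched together as follows. This is an illustrative sketch, not the PR's code: only the torch.cuda.get_device_capability call and the SM90 block-size heuristic quoted later in this review reflect the actual change; the helper name pick_tile_size and the pre-Hopper fallback values are assumptions.

import torch

def pick_tile_size(is_causal, is_local, paged_kv, pagedkv_tma):
    major, minor = torch.cuda.get_device_capability()  # (9, 0) on Hopper, (8, 0) on A100
    arch = major * 10 + minor
    if arch >= 90:
        # SM90 path (cf. tile_size_fwd_sm90): narrower kBlockN when causal/local
        # masking or non-TMA paged KV forces extra predication.
        paged_kv_non_tma = paged_kv and not pagedkv_tma
        block_n = 128 if (is_causal or is_local or paged_kv_non_tma) else 176
        return 128, block_n
    # Pre-Hopper path: defer to the SM8x heuristics (cf. tile_size_fwd_sm8x);
    # the concrete values here are placeholders rather than the PR's numbers.
    return 128, 128
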
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the SM90 architecture in get_scheduler_metadata by adding new heuristics and refactoring Triton kernels. The changes are extensive and improve performance by vectorizing kernel operations.

My review focuses on several key areas:

  • Critical Performance Issue: I've identified a redundant GPU kernel launch which should be fixed.
  • Code Clarity and Maintainability: I've suggested removing dead code and simplifying some expressions to make the code easier to understand and maintain.
  • Consistency: I've pointed out a few minor issues like a non-English error message and redundant code for better consistency across the codebase.

Overall, the changes are a significant step forward, but the critical performance issue and dead code should be addressed before merging.

Comment on lines 698 to 743
seqlen_q = (
    seqused_q
    if seqused_q is not None
    else torch.full((batch_size,), max_seqlen_q, dtype=dtype, device=device)
)
seqlen_k = seqused_k
seqlen_knew = (
    torch.full((batch_size,), max_seqlen_k_new, dtype=dtype, device=device)
    if max_seqlen_k_new > 0
    else None
)

num_m_blocks = torch.empty_like(seqlen_q)
num_n_blocks = torch.empty_like(seqlen_k)
total_blocks = torch.zeros((1,), dtype=dtype, device=device)
num_splits_dynamic = torch.empty_like(seqlen_q)

BLOCK_SIZE_B = 128
grid = (triton.cdiv(batch_size, BLOCK_SIZE_B),)

_prepare_pass1_kernel[grid](
    num_m_blocks,
    num_n_blocks,
    total_blocks,
    seqlen_k,
    cu_seqlens_q,
    cu_seqlens_k,
    cu_seqlens_k_new,
    seqused_q,
    seqused_k,
    leftpad_k,
    batch_size,
    qhead_per_khead,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k_new=max_seqlen_k_new,
    BLOCK_M=blockM,
    BLOCK_N=blockN,
    BLOCK_SIZE_B=BLOCK_SIZE_B,
    HAS_CU_SEQLENS_Q=cu_seqlens_q is not None,
    HAS_CU_SEQLENS_K=cu_seqlens_k is not None,
    HAS_SEQUSED_Q=seqused_q is not None,
    HAS_SEQUSED_K=True,
    HAS_LEFT_PAD=leftpad_k is not None,
    HAS_K_NEW=seqlen_knew is not None,
    HAS_CU_SEQLENS_K_NEW=cu_seqlens_k_new is not None,
)

critical

This block of code, which prepares tensors and launches the _prepare_pass1_kernel, is duplicated later in the function (lines 835-865). This results in launching the same kernel twice, which is inefficient. This entire block should be removed to avoid the redundant kernel launch and code duplication. The logic should be structured to compute all necessary parameters first, and then launch the kernel once.

Comment on lines +887 to 894
if arch >= 90:
    scheduler_needs_semaphore = (
        (final_is_causal or final_is_local) and (final_num_splits == 1)
    ) or is_varlen
else:
    scheduler_needs_semaphore = (final_is_causal and not is_varlen) or (
        is_varlen and final_num_splits > 1
    )

high

This block of code calculates scheduler_needs_semaphore with complex logic. However, its result is immediately overwritten by a simpler expression on line 901. This makes the block from line 887 to 894 dead code. It should be removed to improve clarity and avoid confusion.

Comment on lines +114 to +119
if use_one_mma_wg:
    return 64, (128 if (is_causal or is_local or paged_kv_non_TMA) else 176)
else:
    return 128, (
        128 if (is_causal or is_local or paged_kv_non_TMA) else 176
    )

medium

The condition is_causal or is_local or paged_kv_non_TMA is used twice. To improve readability and reduce repetition, you could extract it into a boolean variable.

Suggested change
if use_one_mma_wg:
    return 64, (128 if (is_causal or is_local or paged_kv_non_TMA) else 176)
else:
    return 128, (
        128 if (is_causal or is_local or paged_kv_non_TMA) else 176
    )
use_blockN_128 = is_causal or is_local or paged_kv_non_TMA
block_n = 128 if use_blockN_128 else 176
if use_one_mma_wg:
    return 64, block_n
else:
    return 128, block_n


    return kBlockM, kBlockN
if arch >= 90:
    paged_kv_non_TMA = bool(paged_kv and (not pagedkv_tma))

medium

The bool() constructor is redundant here, as the expression paged_kv and (not pagedkv_tma) already evaluates to a boolean.

Suggested change
paged_kv_non_TMA = bool(paged_kv and (not pagedkv_tma))
paged_kv_non_TMA = paged_kv and not pagedkv_tma

    )
    return kBlockM, kBlockN
else:
    kBlockM, kBlockN, kNWarps, kStages, Q_in_regs = tile_size_fwd_sm8x(

medium

The values kNWarps, kStages, and Q_in_regs are unpacked but not used. You can use _ to indicate that these values are intentionally ignored, which improves code clarity.

Suggested change
kBlockM, kBlockN, kNWarps, kStages, Q_in_regs = tile_size_fwd_sm8x(
kBlockM, kBlockN, _, _, _ = tile_size_fwd_sm8x(

heuristic_num_splits = get_num_splits(
if not (is_fp16 or is_bf16):
    raise ValueError(
        f"不支持的数据类型: {qkv_dtype}. FlashAttention只支持: torch.float16, torch.bfloat16"

medium

The error message is in Chinese, while the rest of the codebase and comments are in English. For consistency, it's better to use English for error messages.

Suggested change
f"不支持的数据类型: {qkv_dtype}. FlashAttention只支持: torch.float16, torch.bfloat16"
f"Unsupported data type: {qkv_dtype}. FlashAttention only supports: torch.float16, torch.bfloat16"

Comment on lines +759 to +760
d_rounded = d_rounded
dv_rounded = dv_rounded

medium

These assignments are no-ops: d_rounded and dv_rounded are simply assigned to themselves. They can be safely removed.

@0x45f self-assigned this on Oct 22, 2025.
@0x45f merged commit 078834c into FlagOpen:master on Oct 23, 2025.
11 of 14 checks passed.