[Advanced Compiler] fix get_scheduler_metadata #1010
Conversation
Summary of Changes
Hello @AdvancedCompiler, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request fixes a bug in get_scheduler_metadata and significantly enhances it with SM90-specific heuristics.
Code Review
This pull request introduces support for the SM90 architecture in get_scheduler_metadata by adding new heuristics and refactoring the Triton kernels. The changes are extensive and improve performance by vectorizing kernel operations (a rough sketch of this pattern follows this summary).
My review focuses on several key areas:
- Critical Performance Issue: I've identified a redundant GPU kernel launch which should be fixed.
- Code Clarity and Maintainability: I've suggested removing dead code and simplifying some expressions to make the code easier to understand and maintain.
- Consistency: I've pointed out a few minor issues like a non-English error message and redundant code for better consistency across the codebase.
Overall, the changes are a significant step forward, but the critical performance issue and dead code should be addressed before merging.
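To make "vectorizing kernel operations" concrete: this is not the PR's kernel, but a rough illustration of a prepare pass that processes BLOCK_SIZE_B batch entries per program, matching the grid pattern visible in the quoted launch below. The exact quantities computed and the accumulation into total_blocks are assumptions.

import triton
import triton.language as tl


@triton.jit
def _prepare_pass1_sketch(
    num_m_blocks_ptr,
    num_n_blocks_ptr,
    total_blocks_ptr,
    seqlen_q_ptr,
    seqlen_k_ptr,
    batch_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
):
    # One program handles BLOCK_SIZE_B batch entries at once, so the per-batch
    # block counts are computed with vector loads/stores instead of a scalar loop.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
    mask = offs < batch_size
    seqlen_q = tl.load(seqlen_q_ptr + offs, mask=mask, other=0)
    seqlen_k = tl.load(seqlen_k_ptr + offs, mask=mask, other=0)
    num_m = tl.cdiv(seqlen_q, BLOCK_M)  # query tiles per batch entry
    num_n = tl.cdiv(seqlen_k, BLOCK_N)  # key tiles per batch entry
    tl.store(num_m_blocks_ptr + offs, num_m, mask=mask)
    tl.store(num_n_blocks_ptr + offs, num_n, mask=mask)
    # Accumulate a global tile count; which quantity the real kernel totals is an assumption.
    tl.atomic_add(total_blocks_ptr, tl.sum(tl.where(mask, num_m, 0), axis=0))


# Hypothetical host-side launch, mirroring the BLOCK_SIZE_B / grid pattern in the diff:
# grid = (triton.cdiv(batch_size, BLOCK_SIZE_B),)
# _prepare_pass1_sketch[grid](num_m_blocks, num_n_blocks, total_blocks,
#                             seqlen_q, seqlen_k, batch_size,
#                             BLOCK_M=128, BLOCK_N=128, BLOCK_SIZE_B=128)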
seqlen_q = (
    seqused_q
    if seqused_q is not None
    else torch.full((batch_size,), max_seqlen_q, dtype=dtype, device=device)
)
seqlen_k = seqused_k
seqlen_knew = (
    torch.full((batch_size,), max_seqlen_k_new, dtype=dtype, device=device)
    if max_seqlen_k_new > 0
    else None
)

num_m_blocks = torch.empty_like(seqlen_q)
num_n_blocks = torch.empty_like(seqlen_k)
total_blocks = torch.zeros((1,), dtype=dtype, device=device)
num_splits_dynamic = torch.empty_like(seqlen_q)

BLOCK_SIZE_B = 128
grid = (triton.cdiv(batch_size, BLOCK_SIZE_B),)

_prepare_pass1_kernel[grid](
    num_m_blocks,
    num_n_blocks,
    total_blocks,
    seqlen_k,
    cu_seqlens_q,
    cu_seqlens_k,
    cu_seqlens_k_new,
    seqused_q,
    seqused_k,
    leftpad_k,
    batch_size,
    qhead_per_khead,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k_new=max_seqlen_k_new,
    BLOCK_M=blockM,
    BLOCK_N=blockN,
    BLOCK_SIZE_B=BLOCK_SIZE_B,
    HAS_CU_SEQLENS_Q=cu_seqlens_q is not None,
    HAS_CU_SEQLENS_K=cu_seqlens_k is not None,
    HAS_SEQUSED_Q=seqused_q is not None,
    HAS_SEQUSED_K=True,
    HAS_LEFT_PAD=leftpad_k is not None,
    HAS_K_NEW=seqlen_knew is not None,
    HAS_CU_SEQLENS_K_NEW=cu_seqlens_k_new is not None,
)
This block of code, which prepares tensors and launches the _prepare_pass1_kernel, is duplicated later in the function (lines 835-865). This results in launching the same kernel twice, which is inefficient. This entire block should be removed to avoid the redundant kernel launch and code duplication. The logic should be structured to compute all necessary parameters first and then launch the kernel once, as sketched below.
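For illustration only (not the PR's actual restructuring), here is a minimal sketch of the "allocate once, launch once" shape this could take. The helper name and its argument list are hypothetical; the buffer names come from the quoted block.

import torch


def _prepare_scheduler_buffers(seqlen_q, seqlen_k, device, dtype):
    # Hypothetical helper: allocate every buffer the pass-1 kernel writes in one place,
    # so the allocation + launch sequence cannot drift into two diverging copies.
    num_m_blocks = torch.empty_like(seqlen_q)
    num_n_blocks = torch.empty_like(seqlen_k)
    total_blocks = torch.zeros((1,), dtype=dtype, device=device)
    num_splits_dynamic = torch.empty_like(seqlen_q)
    return num_m_blocks, num_n_blocks, total_blocks, num_splits_dynamic


# Inside get_scheduler_metadata the flow would then be:
#   1. derive seqlen_q / seqlen_k / seqlen_knew and the HAS_* flags,
#   2. allocate the buffers via the helper above,
#   3. launch _prepare_pass1_kernel[grid](...) exactly once,
#   4. reuse its outputs wherever the duplicated block at lines 835-865 consumed them.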
if arch >= 90:
    scheduler_needs_semaphore = (
        (final_is_causal or final_is_local) and (final_num_splits == 1)
    ) or is_varlen
else:
    scheduler_needs_semaphore = (final_is_causal and not is_varlen) or (
        is_varlen and final_num_splits > 1
    )
if use_one_mma_wg:
    return 64, (128 if (is_causal or is_local or paged_kv_non_TMA) else 176)
else:
    return 128, (
        128 if (is_causal or is_local or paged_kv_non_TMA) else 176
    )
The condition is_causal or is_local or paged_kv_non_TMA is used twice. To improve readability and reduce repetition, you could extract it into a boolean variable.
Suggested change:

use_blockN_128 = is_causal or is_local or paged_kv_non_TMA
block_n = 128 if use_blockN_128 else 176
if use_one_mma_wg:
    return 64, block_n
else:
    return 128, block_n
return kBlockM, kBlockN

if arch >= 90:
    paged_kv_non_TMA = bool(paged_kv and (not pagedkv_tma))
    )
    return kBlockM, kBlockN
else:
    kBlockM, kBlockN, kNWarps, kStages, Q_in_regs = tile_size_fwd_sm8x(
The values kNWarps, kStages, and Q_in_regs are unpacked but not used. You can use _ to indicate that these values are intentionally ignored, which improves code clarity.
Suggested change:

kBlockM, kBlockN, _, _, _ = tile_size_fwd_sm8x(
heuristic_num_splits = get_num_splits(

if not (is_fp16 or is_bf16):
    raise ValueError(
        f"不支持的数据类型: {qkv_dtype}. FlashAttention只支持: torch.float16, torch.bfloat16"
The error message is in Chinese, while the rest of the codebase and comments are in English. For consistency, it's better to use English for error messages.
f"不支持的数据类型: {qkv_dtype}. FlashAttention只支持: torch.float16, torch.bfloat16" | |
f"Unsupported data type: {qkv_dtype}. FlashAttention only supports: torch.float16, torch.bfloat16" |
d_rounded = d_rounded
dv_rounded = dv_rounded
These self-assignments are redundant no-ops and should be removed.
PR Category
Operator
Type of Change
Bug Fix
Description
Fix get_scheduler_metadata on the SM90 platform.
Issue
Progress
Performance