TODO List #72

@LoserCheems

Description

Project Overview

This project implements an enhanced version of Flash Attention, called Flash Dynamic Masked Attention (Flash-DMAttn), which maintains the memory efficiency and computational speed of Flash Attention while adding Dynamic Mask and Learnable Bias functionalities. These enhancements significantly improve the flexibility and expressiveness of the attention mechanism.

Basic Principles

Flash Attention

Flash Attention is a memory-efficient attention computation method that avoids storing the complete attention matrix through block-wise computation and recomputation strategies, thereby reducing memory requirements and improving GPU utilization. Its core optimization is to decompose attention calculations into smaller sub-blocks, reducing data transfer between HBM and SRAM.
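
As a rough illustration of the block-wise idea, the following PyTorch sketch processes K/V in tiles while maintaining a running, numerically stable softmax, so the full attention matrix is never materialized (a minimal single-head CPU sketch for intuition, not the actual CUDA kernel, which also tiles the queries and works in shared memory):

```python
import torch

def blockwise_attention(q, k, v, block_size=128):
    # q, k, v: (seq_len, head_dim) tensors for a single head.
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((seq_len, 1), float("-inf"))
    row_sum = torch.zeros(seq_len, 1)
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        scores = (q @ kb.T) * scale                                    # (seq_len, block)
        new_max = torch.maximum(row_max, scores.max(-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)                      # rescale old accumulators
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max
    # Matches the dense computation:
    # torch.softmax((q @ k.T) * head_dim ** -0.5, dim=-1) @ v
    return out / row_sum
```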

Dynamic Mask

Dynamic Mask extends the attention mechanism, enabling the model to focus on different parts of the input sequence selectively. Unlike traditional static masks (such as causal masks), dynamic masks can:

  • Create sparse attention patterns based on input content (see the sketch after this list)
  • Simulate complex attention constraints such as local attention, hierarchical attention, etc.
  • Accelerate attention computation by skipping unnecessary calculations
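
To make the first point concrete, here is a hypothetical sketch of a content-dependent sparse mask that keeps only the top-k keys per query; the function name and selection rule are illustrative, not the project's actual masking logic:

```python
import torch

def topk_dynamic_mask(q, k, keep=64):
    # Build a boolean mask of shape (query_len, key_len) that keeps, for each
    # query, only the `keep` highest-scoring keys, yielding a sparse,
    # content-dependent attention pattern.
    scores = q @ k.T
    idx = scores.topk(min(keep, k.shape[0]), dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, idx, True)
    return mask
```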

Learnable Bias

Learnable Bias allows the model to make fine-grained adjustments to attention scores:

  • Add trainable bias terms for different position pairs (see the sketch after this list)
  • Encode positional information or prior knowledge
  • Adjust attention distribution, enhancing the model's expressiveness
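
A minimal sketch of the idea, assuming one trainable bias per (head, query position, key position) that is added to the attention scores before softmax; the module name and shapes are illustrative:

```python
import torch
import torch.nn as nn

class AttentionBias(nn.Module):
    def __init__(self, num_heads, max_len):
        super().__init__()
        # One learnable bias per (head, query, key) position pair.
        self.bias = nn.Parameter(torch.zeros(num_heads, max_len, max_len))

    def forward(self, scores):
        # scores: (batch, num_heads, query_len, key_len)
        q_len, k_len = scores.shape[-2:]
        return scores + self.bias[:, :q_len, :k_len]
```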

Project Progress

Currently working on integrating dynamic_mask_attention_python into flash attention.

The integration work consists of two main parts:

  • Dynamic mask computation: the Python side pre-computes Mask(batch, num_heads, query_len, key_len) and Bias(batch, num_heads, query_len, key_len) and passes them to the CUDA backend, which must load them correctly from global memory into shared memory and then into registers.
  • Sparse attention weight computation: when computing QK^T, the Mask is used to accumulate the matrix multiplication sparsely; the masked positions are then set to -inf, the valid positions are scaled, and the Bias is added. When multiplying the attention scores by V, the Mask is used again to accumulate the matrix multiplication sparsely.
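
A dense PyTorch reference of this forward flow (illustrative only; the CUDA kernel uses the Mask to skip work rather than computing everything densely, and the function name here is hypothetical):

```python
import torch

def dynamic_mask_attention_reference(q, k, v, mask, bias):
    # Assumed shapes, following the description above:
    #   q: (batch, num_heads, query_len, head_dim)
    #   k, v: (batch, num_heads, key_len, head_dim)
    #   mask: (batch, num_heads, query_len, key_len), bool, True = keep
    #   bias: (batch, num_heads, query_len, key_len), float
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1))         # QK^T
    scores = scores * scale + bias                        # scale valid parts, add Bias
    scores = scores.masked_fill(~mask, float("-inf"))     # masked parts -> -inf
    probs = torch.softmax(scores, dim=-1)
    probs = torch.nan_to_num(probs)                       # guard fully-masked rows
    return torch.matmul(probs, v)                         # attention scores x V
```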

Work completed so far:

  • Completed parameter structure definition in @flash.h, including Mask and Bias definitions.
  • Completed shared memory layout and copy layout definitions for Mask and Bias in @kernel_traits.h.
  • Completed attn_mask_offset and attn_bias_offset in @block_info.h for Mask and Bias.
  • Completed copy_Mask implementation in @utils.h, including loading logic from global memory to shared memory.
  • Completed the dynamic mask computation logic (apply_mask) in @mask.h, including Mask and Bias application.
  • Completed sparse matrix multiplication in @utils.h: sparse_gemm, sparse_gemm_rs.
  • Completed compute_attn_1rowblock adaptation for Mask and Bias in @flash_fwd_kernel.h.
  • Completed launch function adaptation for Mask and Bias in @flash_fwd_launch_template.h.
  • Completed interface functions adaptation for Mask and Bias in @flash_api.cpp.
  • Completed forward equivalence testing in @benchmark_forward_equivalence.py.
  • Completed forward performance testing in @benchmark_forward_performance.py.

The above completed work may need to be further improved to accommodate new requirements as the project progresses.

Work that needs to be completed now:

Implement Mask and Bias adaptation for compute_dq_dk_dv_1colblock in @flash_bwd_kernel.h, referencing the dBias calculation logic in the backward pass of @flash_dmattn_triton.py.
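
For reference, this is the standard math the backward pass needs: since S = scale · QK^T + Bias at unmasked positions, dBias is simply dS, the gradient of the pre-softmax scores. A dense sketch follows (shapes as in the forward reference above; this is not the actual Triton or CUDA implementation):

```python
import torch

def attention_backward_dbias(q, k, v, mask, bias, d_out):
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale + bias
    scores = scores.masked_fill(~mask, float("-inf"))
    p = torch.softmax(scores, dim=-1).nan_to_num()
    dp = torch.matmul(d_out, v.transpose(-2, -1))          # dP = dO V^T
    ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))     # softmax backward
    ds = ds.masked_fill(~mask, 0.0)                        # no gradient at masked positions
    d_bias = ds                                            # dBias = dS (element-wise)
    dq = torch.matmul(ds, k) * scale
    dk = torch.matmul(ds.transpose(-2, -1), q) * scale
    dv = torch.matmul(p.transpose(-2, -1), d_out)
    return d_bias, dq, dk, dv
```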

Metadata

Labels

feature (New feature request), question (Further information is requested)
