[FEATURE REQUEST] Unified Sparse Mask Strategy with Block-Level Skipping #163

@LoserCheems

Description

Is your feature request related to a problem? Please describe.
Current usage patterns (long-context, multi-document, selective attention) require masking strategies beyond simple causal or fixed windows. A naïve dense L×L mask (even bit-packed) inflates per-layer peak memory and prevents fully exploiting block-level skipping in kernels. For 32K sequences, materializing logical S/P or full masks—even transiently—creates pressure on memory and bandwidth and limits scalability.

Describe the solution you'd like
Introduce a unified mask abstraction that:

  • Supports: causal mask, document (intra-segment) mask, windowed banded mask, hybrid doc+causal, dynamic (runtime-pruned) mask.
  • Represents masks in compressed forms:
    • Parametric (causal, sliding window) → no storage.
    • Block bitset (B×B granularity) for moderate sparsity.
    • BCSR (row_ptr + col_idx) for irregular sparse patterns.
    • Mixed granularity: dense blocks + few partial blocks (bitpacked).
  • Kernel API accepts a lightweight per-query-block descriptor to enumerate active key blocks; token-level refinement is generated on the fly if needed.
  • Forward & backward share identical block activation logic (OR(mask_block) pre-check) to skip inactive tiles before K/V loads.
  • Optional dynamic refinement hook (user callback or pre-pass) for top-k / score-threshold pruning without global L×L reduction.
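
As a sketch of what the unified abstraction could look like (class and method names such as `BlockMask` and `active_key_blocks` are illustrative, not an existing API), the key property is that every mask, parametric or stored, answers one question per query block: which key blocks are active? Parametric masks answer it with zero storage:

```python
from abc import ABC, abstractmethod
from typing import Iterator


class BlockMask(ABC):
    """Hypothetical unified mask interface: each subclass enumerates
    the key blocks visible to a given query block."""

    def __init__(self, seq_len: int, block: int = 128):
        self.seq_len = seq_len
        self.block = block
        self.n_blocks = (seq_len + block - 1) // block

    @abstractmethod
    def active_key_blocks(self, q_block: int) -> Iterator[int]:
        """Yield key-block ids visible to query block `q_block`."""


class CausalMask(BlockMask):
    """Parametric: the active range is computed, never stored."""

    def active_key_blocks(self, q_block: int) -> Iterator[int]:
        return iter(range(q_block + 1))  # every block up to the diagonal


class WindowMask(BlockMask):
    """Parametric causal sliding window of `window` tokens."""

    def __init__(self, seq_len: int, window: int, block: int = 128):
        super().__init__(seq_len, block)
        self.w_blocks = (window + block - 1) // block  # window in blocks

    def active_key_blocks(self, q_block: int) -> Iterator[int]:
        lo = max(0, q_block - self.w_blocks)
        return iter(range(lo, q_block + 1))
```

For example, `CausalMask(512, block=128).active_key_blocks(2)` yields blocks 0, 1, 2; a stored-form subclass (bitset or BCSR) would implement the same method by reading its compressed structure, so the kernel-facing descriptor stays uniform across mask types.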

Describe alternatives you've considered

  • Full L×L mask tensor (too large, defeats memory savings).
  • Storing per-row start/end (works only for pure band/window).
  • Pure top-k inside kernel (requires global reductions per row; high latency).
  • Precomputing pruned indices at Python level (overhead + host/device sync).

Implementation details

  • CUDA/Triton kernel changes:
    • Add indirection layer: query_block_id → (bitset | block index list | param spec).
    • Fast path specialization for parametric (no loads).
    • Shared-memory staging of active col block IDs; warp ballot over bitset.
    • Optional partial-block submask bitpack (<= 128 bits) applied after QKᵀ tile.
  • Python API: new Mask classes (CausalMask, WindowMask, DocMask, DynamicMask).
  • Dynamic mask: optional prepass producing BCSR; or callback registered to produce active block list per step.
  • Performance: reduces unnecessary K/V reads; improves arithmetic intensity by eliminating zero tiles; keeps log-sum-exp streaming unchanged.
  • GPU arch: design keeps coalesced loads (reorder col_idx to monotonic); Hopper/Ampere friendly (avoid divergence via grouped warp processing).
  • Autograd: backward reuses same block metadata; no need to store dense mask.
  • Fallback path: if the mask object is not recognized, fall back to materializing a dense mask.
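
To make the indirection-layer idea concrete, here is a host-side sketch (illustrative, not the actual kernel-side layout) of converting a per-query-block bitset into the BCSR form (`row_ptr` + `col_idx`) mentioned above, keeping `col_idx` monotonic per row so the kernel's key-block loads stay coalesced:

```python
def bitset_to_bcsr(bitset_rows):
    """Convert a block bitset (one int per query block, bit j set when
    key block j is active) into BCSR-style (row_ptr, col_idx) lists.

    Walking bits from LSB to MSB naturally emits col_idx in increasing
    order per row, matching the "reorder col_idx to monotonic" note.
    """
    row_ptr, col_idx = [0], []
    for bits in bitset_rows:
        j = 0
        while bits:
            if bits & 1:
                col_idx.append(j)
            bits >>= 1
            j += 1
        row_ptr.append(len(col_idx))  # prefix sum of active blocks
    return row_ptr, col_idx
```

For instance, rows `[0b101, 0b010]` produce `row_ptr = [0, 2, 3]` and `col_idx = [0, 2, 1]`; the kernel then enumerates row i's active blocks as `col_idx[row_ptr[i]:row_ptr[i+1]]` with no token-level mask loads.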

Describe alternatives you've considered
(See above; each was discarded due to memory cost or inflexibility.)

Use case

  • Long-document models (e.g., 32K–128K) mixing causal + intra-document isolation.
  • Retrieval-augmented decoding where only retrieved segments should be visible.
  • Structured band + sparse global token patterns (hybrid).
  • Dynamic pruning (adaptive focus) without exploding memory.
    Target benefit: lower per-layer peak memory, enabling larger batch or longer context while retaining correctness.

Additional context

  • Forward FLOPs shrink from 2L²d to ~2L·W_eff·d; backward from ~5L²d to ~5L·W_eff·d, where W_eff is the effective width determined by the active blocks.
  • Memory (logical coverage): dense S/P is 2L² entries, vs. L² + 2LW only in a “store the mask” framing; with fully parametric masks plus recompute, actual resident mask storage can approach O((L/B)·a) (a = active key blocks per query block) or O(L).
  • Block size tuning (e.g., B=128) yields large compression (bitset ~ O((L/B)²) bits).
  • Dynamic path must avoid per-row global reductions; propose two-stage (block summary prefilter → fine QKᵀ inside selected blocks).
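
Plugging illustrative numbers into the formulas above (L = 32K, B = 128, head dim d = 128, and an assumed effective width W_eff = 4096) makes the savings concrete:

```python
# Worked example of the compression and FLOP claims above.
# L, B, d are from the text; W_eff = 4096 is an assumed value.
L, B, d, W_eff = 32768, 128, 128, 4096

dense_mask_bits = L * L            # bit-packed dense L x L mask
block_bitset_bits = (L // B) ** 2  # one bit per B x B block

fwd_dense = 2 * L * L * d          # 2 L^2 d
fwd_sparse = 2 * L * W_eff * d     # 2 L W_eff d

print(dense_mask_bits // 8 // 2**20, "MiB dense vs",
      block_bitset_bits // 8 // 2**10, "KiB bitset")
# → 128 MiB dense vs 8 KiB bitset
print("forward FLOP reduction: %dx" % (fwd_dense // fwd_sparse))
# → forward FLOP reduction: 8x
```

So even a bit-packed dense mask costs 128 MiB per layer at 32K, while the block bitset is 8 KiB (a 16,384× reduction), and the forward FLOP reduction is simply L / W_eff.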

Related work

  • FlashAttention: streaming softmax avoids storing full P/S.
  • Longformer / BigBird: pattern-based sparsity (parametric masks).
  • Sparse Attention (BlockSparse, OpenAI): BCSR-like indexing for blocks.
  • Top-k / adaptive sparsity papers (e.g., Routing Transformer) motivate dynamic selection.
    Value: unify these patterns under one lightweight runtime + kernel interface without forcing users to choose incompatible attention ops.

Proposed next steps

  1. Define Mask interface.
  2. Implement parametric + block bitset path.
  3. Add mixed granularity.
  4. Integrate into forward kernel (enumerate active blocks).
  5. Mirror logic in backward.
  6. Add benchmarks (memory peak, TFLOPs, latency) on 8K/16K/32K vs dense mask.
  7. Extend with dynamic pruning hook.
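
For step 7, the two-stage dynamic path proposed above could be sketched as follows (a host-side illustration, not the kernel implementation): stage 1 scores each (query block, key block) pair with a cheap block summary, e.g. a pooled |QKᵀ| estimate, and stage 2 keeps the top-k key blocks per query block; fine QKᵀ is then computed only inside the selected blocks, so no per-row global L×L reduction is needed.

```python
import heapq


def select_blocks_two_stage(block_scores, k):
    """Stage-2 selection: given a (n_q_blocks x n_k_blocks) nested list
    of block-summary scores from the prefilter, keep the top-k key
    blocks per query block. Returns one sorted id list per query block.
    """
    active = []
    for row in block_scores:
        # top-k by summary score; a per-row heap, not a global reduction
        top = heapq.nlargest(k, range(len(row)), key=row.__getitem__)
        active.append(sorted(top))  # monotonic ids keep loads coalesced
    return active
```

For example, `select_blocks_two_stage([[0.1, 0.9, 0.3], [0.5, 0.2, 0.8]], k=2)` returns `[[1, 2], [0, 2]]`; the output plugs directly into the same BCSR/block-list descriptor the static masks use, so the forward and backward kernels need no dynamic-specific code path.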
