Is your feature request related to a problem? Please describe.
Current usage patterns (long-context, multi-document, selective attention) require masking strategies beyond simple causal or fixed windows. A naïve dense L×L mask (even bit-packed) inflates per-layer peak memory and prevents fully exploiting block-level skipping in kernels. For 32K sequences, materializing logical S/P or full masks—even transiently—creates pressure on memory and bandwidth and limits scalability.
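As a back-of-envelope illustration of that pressure (per head, per layer, at L = 32K):

```python
# Footprint of materialized masks/scores at L = 32K, per head, per layer.
L = 32_768
dense_bool_mask = L * L            # 1 byte per entry  -> 1.0 GiB
bitpacked_mask  = L * L // 8       # 1 bit per entry   -> 128 MiB, still large
fp16_scores     = 2 * L * L        # materialized S/P  -> 2.0 GiB
print(f"bool mask: {dense_bool_mask / 2**30:.1f} GiB, "
      f"bit-packed: {bitpacked_mask / 2**20:.0f} MiB, "
      f"fp16 S: {fp16_scores / 2**30:.1f} GiB")
```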
Describe the solution you'd like
Introduce a unified mask abstraction that:
- Supports: causal mask, document (intra-segment) mask, windowed banded mask, hybrid doc+causal, dynamic (runtime-pruned) mask.
- Represents masks in compressed forms:
- Parametric (causal, sliding window) → no storage.
- Block bitset (B×B granularity) for moderate sparsity.
- BCSR (row_ptr + col_idx) for irregular sparse patterns.
- Mixed granularity: dense blocks + few partial blocks (bitpacked).
- Kernel API accepts a lightweight per-query-block descriptor to enumerate active key blocks; token-level refinement is generated on the fly if needed (see the sketch after this list).
- Forward & backward share identical block activation logic (OR(mask_block) pre-check) to skip inactive tiles before K/V loads.
- Optional dynamic refinement hook (user callback or pre-pass) for top-k / score-threshold pruning without global L×L reduction.
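A minimal Python sketch of how these representations could answer the same per-query-block question; all class and method names here are illustrative, not an existing API:

```python
# Hypothetical sketch of the proposed Mask abstraction.
from dataclasses import dataclass
import numpy as np

@dataclass
class CausalMask:
    """Parametric: no storage; active blocks follow from the query block id."""
    def active_key_blocks(self, qb: int):
        return range(0, qb + 1)               # every key block at or before qb

@dataclass
class WindowMask:
    """Parametric sliding window, measured in blocks."""
    window_blocks: int
    def active_key_blocks(self, qb: int):
        return range(max(0, qb - self.window_blocks), qb + 1)

class BlockBitsetMask:
    """One bit per (query block, key block) pair: O((L/B)^2) bits total."""
    def __init__(self, n_qblocks: int, n_kblocks: int):
        self.bits = np.zeros((n_qblocks, n_kblocks), dtype=bool)
    def active_key_blocks(self, qb: int):
        return np.flatnonzero(self.bits[qb])  # OR-style pre-check per block row

class BCSRMask:
    """row_ptr + col_idx over blocks, for irregular sparse patterns."""
    def __init__(self, row_ptr: np.ndarray, col_idx: np.ndarray):
        self.row_ptr, self.col_idx = row_ptr, col_idx
    def active_key_blocks(self, qb: int):
        return self.col_idx[self.row_ptr[qb]:self.row_ptr[qb + 1]]
```

Each class answers the same question (which key blocks are active for this query block?), which is exactly the descriptor the kernel indirection layer consumes.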
Describe alternatives you've considered
- Full L×L mask tensor (too large, defeats memory savings).
- Storing per-row start/end (works only for pure band/window).
- Pure top-k inside kernel (requires global reductions per row; high latency).
- Precomputing pruned indices at Python level (overhead + host/device sync).
Implementation details
- CUDA/Triton kernel changes:
- Add indirection layer: query_block_id → (bitset | block index list | param spec).
- Fast path specialization for parametric (no loads).
- Shared-memory staging of active col block IDs; warp ballot over bitset.
- Optional partial-block submask bitpack (<= 128 bits) applied after QKᵀ tile.
- Python API: new Mask classes (CausalMask, WindowMask, DocMask, DynamicMask).
- Dynamic mask: optional prepass producing BCSR; or callback registered to produce active block list per step.
- Performance: reduces unnecessary K/V reads; improves arithmetic intensity by eliminating zero tiles; keeps log-sum-exp streaming unchanged.
- GPU arch: design keeps loads coalesced (reorder col_idx to be monotonic); Hopper/Ampere friendly (avoid divergence via grouped warp processing).
- Autograd: backward reuses same block metadata; no need to store dense mask.
- Fallback path: if the mask object is not recognized, fall back to materializing a dense mask (see the reference sketch below).
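A host-side NumPy reference of the intended block-skipping control flow (the real path is CUDA/Triton; `partial_submask` is an assumed token-refinement hook that returns None for fully active blocks):

```python
import numpy as np

def blockwise_attention_ref(Q, K, V, mask, B=128):
    L, d = Q.shape                      # assumes L divisible by B, and that every
    out = np.empty_like(Q)              # query row sees >= 1 unmasked key overall
    for qb in range(L // B):
        q = Q[qb*B:(qb+1)*B]
        m = np.full(B, -np.inf)         # streaming log-sum-exp state, unchanged
        l = np.zeros(B)                 # from dense FlashAttention
        acc = np.zeros_like(q)
        for kb in mask.active_key_blocks(qb):          # inactive tiles skipped
            k, v = K[kb*B:(kb+1)*B], V[kb*B:(kb+1)*B]  # before K/V loads
            s = (q @ k.T) / np.sqrt(d)
            sub = mask.partial_submask(qb, kb)
            if sub is not None:                        # refinement after QK^T tile
                s = np.where(sub, s, -np.inf)
            m_new = np.maximum(m, s.max(axis=1))
            scale = np.exp(m - m_new)
            p = np.exp(s - m_new[:, None])
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ v
            m = m_new
        out[qb*B:(qb+1)*B] = acc / l[:, None]
    return out
```

The backward pass would walk the same `active_key_blocks` enumeration, which is why no dense mask ever needs to be stored for autograd.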
Use case
- Long-document models (e.g., 32K–128K) mixing causal + intra-document isolation.
- Retrieval-augmented decoding where only retrieved segments should be visible.
- Structured band + sparse global token patterns (hybrid).
- Dynamic pruning (adaptive focus) without exploding memory.
Target benefit: lower per-layer peak memory, enabling larger batches or longer contexts while retaining correctness.
Additional context
- Forward FLOPs shrink from 2L²d to ~2L·W_eff·d; backward from ~5L²d to ~5L·W_eff·d, where W_eff ≈ a·B is the effective number of active key positions per query row (a active blocks of size B).
- Logical memory coverage: full S/P costs 2L² entries, versus (L² + 2LW) in a “store the mask” framing; with fully parametric masks plus recompute, actual resident mask storage can approach O((L/B)·a) or O(L).
- Block size tuning (e.g., B=128) yields large compression: a block bitset needs only ~O((L/B)²) bits.
- The dynamic path must avoid per-row global reductions; we propose a two-stage scheme (block-summary prefilter → fine QKᵀ inside the selected blocks), sketched below.
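A sketch of that two-stage idea under simplifying assumptions (mean pooling for block summaries and fixed top-`keep` selection are illustrative choices, not settled design); the point is that stage 1 costs O((L/B)²d) instead of O(L²d) and needs no per-row L-wide reduction:

```python
import numpy as np

def block_prefilter(Q, K, B=128, keep=8):
    L, d = Q.shape                               # assumes L divisible by B, keep <= L//B
    n = L // B
    q_blk = Q.reshape(n, B, d).mean(axis=1)      # stage 1: cheap block summaries
    k_blk = K.reshape(n, B, d).mean(axis=1)
    scores = q_blk @ k_blk.T                     # (n, n) block-level affinities
    top = np.argsort(-scores, axis=1)[:, :keep]  # coarse top-k per query block
    top.sort(axis=1)                             # monotonic col_idx keeps loads coalesced
    row_ptr = np.arange(n + 1) * keep
    return row_ptr, top.ravel()                  # BCSR over blocks
```

The returned (row_ptr, col_idx) pair plugs directly into the BCSR representation above; stage 2 (fine QKᵀ) then runs only inside the selected blocks.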
Related work
- FlashAttention: streaming softmax avoids storing full P/S.
- Longformer / BigBird: pattern-based sparsity (parametric masks).
- Sparse Attention (BlockSparse, OpenAI): BCSR-like indexing for blocks.
- Top-k / adaptive sparsity papers (e.g., Routing Transformer) motivate dynamic selection.
Value: unify these patterns under one lightweight runtime + kernel interface without forcing users to choose incompatible attention ops.
Proposed next steps
- Define Mask interface.
- Implement parametric + block bitset path.
- Add mixed granularity.
- Integrate into forward kernel (enumerate active blocks).
- Mirror logic in backward.
- Add benchmarks (peak memory, TFLOPs, latency) at 8K/16K/32K versus a dense mask (a possible harness is sketched below).
- Extend with dynamic pruning hook.
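A possible shape for that benchmark harness, using standard PyTorch CUDA memory stats; `dense_attn` and `masked_attn` are placeholders for entry points that do not exist yet:

```python
import time
import torch

def bench(attn_fn, L, d=128, heads=8, iters=10, device="cuda"):
    q = torch.randn(1, heads, L, d, device=device, dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    attn_fn(q, k, v)                              # warm-up / compile
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        attn_fn(q, k, v)
    torch.cuda.synchronize()
    latency_ms = (time.perf_counter() - t0) / iters * 1e3
    peak_mib = torch.cuda.max_memory_allocated(device) / 2**20
    return latency_ms, peak_mib

# for L in (8192, 16384, 32768):
#     print(L, bench(dense_attn, L), bench(masked_attn, L))
```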