Describe the bug
During the backward equivalence tests (dense path; dynamic mask and learnable bias temporarily disabled), gradients dV occasionally contain NaN / Inf values, while dQ, dK, the forward output, and the softmax log-sum-exp remain numerically stable and within tolerance. This indicates a localized correctness / masking / OOB-predication issue specific to the dV accumulation or store path.
To Reproduce
Steps to reproduce:
- Install and build Flash-DMAttn (Support-backward branch).
- Run the backward equivalence benchmark with CUDA launch blocking enabled:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Run from repository root
# Equivalent CLI:
# CUDA_LAUNCH_BLOCKING=1 python benchmarks/backward_equivalence.py --test-type cuda
from subprocess import run
run("python benchmarks/backward_equivalence.py --test-type cuda", shell=True)
- Observe the test output: only dV fails equivalence (presence of NaN / Inf), while the others pass.
Minimal inline snippet (mirroring the failing configuration):
import torch
import flash_dmattn
torch.manual_seed(42)
device = "cuda"
B, H, HKV = 1, 1, 1
Q_LEN = 256
K_LEN = 256
D = 64
is_causal = True
q = torch.randn(B, Q_LEN, H, D, device=device, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, K_LEN, HKV, D, device=device, dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, K_LEN, HKV, D, device=device, dtype=torch.bfloat16, requires_grad=True)
attn_mask = None
attn_bias = None
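# mask / bias intentionally left as None: dense path, matching the failing configuration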
out = flash_dmattn.triton_dmattn_func(q, k, v, attn_mask, attn_bias, is_causal=is_causal, scale=None)
loss = out.sum()
loss.backward()
print("Has NaN in dV:", torch.isnan(v.grad).any().item(), "Has NaN in dK:", torch.isnan(k.grad).any().item(), "Has NaN in dQ:", torch.isnan(q.grad).any().item())
Expected behavior
All backward gradients (dQ, dK, dV, and the optional dBias) should be finite and closely match the Python reference (equivalence-test threshold passed), with no NaN / Inf values.
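For concreteness, a hedged sketch of this kind of check (finiteness plus an elementwise tolerance ratio, mirroring the benchmark log below); the rtol / atol / ratio thresholds here are assumptions, not the benchmark's actual values:
import torch

def equivalence_report(name, test, ref, rtol=1e-2, atol=1e-2, min_ratio=0.99):
    # Assumed thresholds for illustration; the benchmark's real tolerances may differ.
    finite = torch.isfinite(test).all().item()
    close = torch.isclose(test.float(), ref.float(), rtol=rtol, atol=atol)
    ratio = close.float().mean().item()
    print(f"{name}: finite={finite}, within-tolerance ratio={ratio:.4f} "
          f"({close.sum().item()}/{close.numel()})")
    return finite and ratio >= min_ratio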
Environment Information
(Example run; please adjust if different)
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.version.cuda}'); print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}')"
Reported during failing run:
- PyTorch: 2.8.0a0+5228986c39.nv25.05
- GPU: NVIDIA GeForce RTX 4090
- CUDA context active; dtype tested: bfloat16 (bf16).
OS: Linux (container base image)
Python: 3.12
Flash-DMAttn branch: Support-backward
Additional context
- Dynamic mask & learnable bias logic were temporarily disabled (running dense path) to isolate the backward math.
- Only
dV
exhibitsNaN
/Inf
;dQ
,dK
, forward output, and softmax LSE all pass with full element ratio within tolerance. - Forward path equivalence & performance tests already pass.
Error traceback
No Python exception is thrown; the issue is silent numerical corruption. Sample benchmark log excerpt (trimmed):
Analyzing dV gradients:
Original result range: [-2.718750, 4.156250]
CUDA result range: [nan, nan]
Original result contains NaN: False, Inf: False
CUDA result contains NaN: True, Inf: True
Elements within tolerance ratio: 0.8490 (13910/16384)
Accuracy threshold: Fail
Debugging Information
Config where failure reproduced:
- batch_size=1
- num_heads=1, num_kv_heads=1
- query_len=256, key_len=256
- head_dim=64
- is_causal=True
- dtype: bfloat16
- Path: dense (mask / bias turned off)
Observations & Hypothesis:
- dV path = GEMM (P^T @ dO). Only this accumulation shows NaNs ⇒ likely uninitialized or OOB (out-of-bounds) lane participation in MMA or store.
- Code comments near the dV writeback explicitly disable OOB clearing (“Clear_OOB_K must be false…”), which increases the risk if predicate masking is incomplete for tail tiles or vectorized BF16 stores.
- Even with dimensions seemingly aligned (kBlockN divides seqlen_k, head_dim divides the tile headdim), internal MMA atom / ValLayout subdivisions (e.g., warpcontiguousN with factor 2) can create per-lane partial OOB that requires predication.
- Potential sources:
- Missing zfill / predicate on cp.async loads for last tile of P or dO ⇒ NaN travels into accumulator.
- Vectorized BF16 store of acc_dv packs stray register values (uninitialized lanes) when Clear_OOB is disabled.
- Shared-memory region reuse or offset miscalculation for sdO / sPt causing overlap before consumers finish.
- Less likely (since dQ/dK are fine):
- Softmax LSE corruption (would affect all grads).
- Scale or softcap overflow (others would fail too).
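One way to confirm the failure is confined to the dV path is to recompute dV from an fp32 PyTorch reference of the same math and compare it element-wise with the kernel's v.grad. This is a sketch; reference_dv and its layout assumptions are mine, not the benchmark's actual reference implementation:
import torch

def reference_dv(q, k, v, is_causal=True):
    # fp32 reference of the backward GEMM under suspicion: dV = P^T @ dO,
    # with P = softmax(Q K^T * scale) and dO = ones (matching loss = out.sum()).
    # Assumes (B, L, H, D) layouts as in the snippet above and H == HKV.
    qh, kh, vh = (t.detach().float().transpose(1, 2) for t in (q, k, v))
    scale = qh.shape[-1] ** -0.5
    scores = torch.matmul(qh, kh.transpose(-1, -2)) * scale
    if is_causal:
        L_q, L_k = scores.shape[-2:]
        causal = torch.ones(L_q, L_k, dtype=torch.bool, device=scores.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    p = scores.softmax(dim=-1)
    dO = torch.ones_like(torch.matmul(p, vh))
    dv = torch.matmul(p.transpose(-1, -2), dO)      # the P^T @ dO accumulation
    return dv.transpose(1, 2)                       # back to (B, L, H, D)

# dv_ref = reference_dv(q, k, v)
# bad = ~torch.isfinite(v.grad.float())
# print("non-finite dV elements:", bad.sum().item(),
#       "| reference finite:", torch.isfinite(dv_ref).all().item())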
Suggested Immediate Diagnostics:
- Add a debug guard before the dV store: scan acc_dv (the float accumulator) for non-finite values; if present, log the lane indices.
- Temporarily enable OOB clearing (or a predicated store) for dV tail tiles; see if the NaNs disappear.
- Force a scalar (non-vectorized) store for dV to test the alignment hypothesis.
- Run with fp32 accumulation + fp32 output (avoiding the bf16 cast) to see whether NaNs only arise after conversion.
- Test non-divisible shapes (e.g., key_len=192, head_dim=80) to amplify the OOB surface; if the NaN frequency changes, that confirms a predication issue (see the sketch below).
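A sketch of the shape sweep suggested in the last bullet above (assumes the same triton_dmattn_func entry point as the snippet and that the chosen shapes are supported by the kernel):
import torch
import flash_dmattn

def dv_nan_rate(q_len, k_len, head_dim, trials=5, dtype=torch.bfloat16):
    # Fraction of trials in which dV contains non-finite values for a given shape.
    bad = 0
    for seed in range(trials):
        torch.manual_seed(seed)
        q = torch.randn(1, q_len, 1, head_dim, device="cuda", dtype=dtype, requires_grad=True)
        k = torch.randn(1, k_len, 1, head_dim, device="cuda", dtype=dtype, requires_grad=True)
        v = torch.randn(1, k_len, 1, head_dim, device="cuda", dtype=dtype, requires_grad=True)
        out = flash_dmattn.triton_dmattn_func(q, k, v, None, None, is_causal=True, scale=None)
        out.sum().backward()
        bad += int(not torch.isfinite(v.grad).all().item())
    return bad / trials

# Aligned shape vs. deliberately non-divisible ones; a higher NaN rate for the latter
# would support the OOB-predication hypothesis.
for q_len, k_len, head_dim in [(256, 256, 64), (256, 192, 80), (250, 250, 72)]:
    print((q_len, k_len, head_dim), dv_nan_rate(q_len, k_len, head_dim))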
Potential Fix Directions:
- Re-introduce Clear_OOB for dV or implement per-lane predicate in gmem store (mask: column < headdim && row < seqlen_k).
- Use cp.async.zfill for tail loads of P and dO tiles.
- Validate shared-memory offset arithmetic for sdO vs sQ vs sP region boundaries.
- Add explicit zero-initialization of the accumulator (clear(acc_dv)); this appears to be present already, but confirm it happens for every n_block iteration.
- Ensure convert_type<Element> does not depend on uninitialized fragment lanes (apply the predicate before conversion).
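As a host-side illustration of why the zfill / Clear_OOB direction matters (not the kernel fix itself; the tile sizes below are hypothetical): uninitialized tail rows in a dO tile land directly in the dV = P^T @ dO accumulator, and zeroed P entries alone do not help because 0 * NaN is still NaN:
import torch

BLOCK_M, BLOCK_N, D = 128, 64, 64         # hypothetical tile sizes, not the kernel's
valid_m = 96                              # last query tile only partially valid

p_tile = torch.rand(BLOCK_M, BLOCK_N)     # softmax probs for one (m, n) tile
p_tile[valid_m:] = 0.0                    # even zeroed OOB rows of P do not help
do_tile = torch.empty(BLOCK_M, D)
do_tile[:valid_m] = torch.randn(valid_m, D)
do_tile[valid_m:] = float("nan")          # stand-in for uninitialized / stale memory

acc_dv_bad = p_tile.T @ do_tile                              # contaminated: 0 * NaN = NaN
acc_dv_ok = p_tile.T @ torch.nan_to_num(do_tile, nan=0.0)    # what zfill would provide
print(torch.isfinite(acc_dv_bad).all().item(),   # False
      torch.isfinite(acc_dv_ok).all().item())    # True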
Workaround (temporary):
- Enable OOB clearing, or explicitly zero any non-finite elements in rdV before the global store (debug-only guard; a Python-side sketch follows below).
- Run backward in fp32 (if memory permits) to mitigate silent NaN propagation from partial BF16 packs (not a real fix, only a scoping aid).
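A sketch of such a debug-only guard on the Python side (zeroing non-finite gradient elements via a tensor hook; purely a scoping aid, not a fix):
import torch

def guard_nonfinite_grad(t, name="dV"):
    # Debug-only: report and zero non-finite elements of the incoming gradient
    # before it is accumulated into t.grad.
    def hook(grad):
        bad = ~torch.isfinite(grad)
        if bad.any():
            print(f"[{name}] zeroing {bad.sum().item()} non-finite grad elements")
            return grad.masked_fill(bad, 0.0)
        return grad
    t.register_hook(hook)

# Usage with the snippet above, before loss.backward():
# guard_nonfinite_grad(v, "dV")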