
[BUG] NaN / Inf values appear only in dV during backward pass #121

@LoserCheems

Description

Describe the bug
During the backward equivalence tests (dense path; dynamic mask and learnable bias temporarily disabled), the dV gradient occasionally contains NaN / Inf values, while dQ, dK, the forward output, and the softmax log-sum-exp remain numerically stable and within tolerance. This points to a localized correctness / masking / OOB-predication issue specific to the dV accumulation or store path.

To Reproduce
Steps to reproduce:

  1. Install and build Flash-DMAttn (Support-backward branch).
  2. Run the backward equivalence benchmark with CUDA launch blocking enabled:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Run from repository root
# Equivalent CLI:
# CUDA_LAUNCH_BLOCKING=1 python benchmarks/backward_equivalence.py --test-type cuda
from subprocess import run
run("python benchmarks/backward_equivalence.py --test-type cuda", shell=True)
  3. Observe the test output: only dV fails the equivalence check (NaN / Inf present), while the others pass.

Minimal inline snippet (mirroring failing configuration):

import torch
import flash_dmattn

torch.manual_seed(42)
device = "cuda"
B, H, HKV = 1, 1, 1
Q_LEN = 256
K_LEN = 256
D = 64
is_causal = True

q = torch.randn(B, Q_LEN, H, D, device=device, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, K_LEN, HKV, D, device=device, dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, K_LEN, HKV, D, device=device, dtype=torch.bfloat16, requires_grad=True)
attn_mask = None
attn_bias = None

out = flash_dmattn.triton_dmattn_func(q, k, v, attn_mask, attn_bias, is_causal=is_causal, scale=None)
loss = out.sum()
loss.backward()

print("Has NaN in dV:", torch.isnan(v.grad).any().item(), "Has NaN in dK:", torch.isnan(k.grad).any().item(), "Has NaN in dQ:", torch.isnan(q.grad).any().item())

Expected behavior
All backward gradients (dQ, dK, dV, and the optional dBias) should be finite, with no NaN / Inf values, and should match the Python reference within the equivalence-test tolerance.

Environment Information
(Example run; please adjust if different)

python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.version.cuda}'); print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}')"

Reported during failing run:

  • PyTorch: 2.8.0a0+5228986c39.nv25.05
  • GPU: NVIDIA GeForce RTX 4090
  • dtype tested: bfloat16 (bf16); CUDA context active
  • OS: container on a Linux base image
  • Python: 3.12
  • Flash-DMAttn branch: Support-backward

Additional context

  • Dynamic mask & learnable bias logic were temporarily disabled (running dense path) to isolate the backward math.
  • Only dV exhibits NaN / Inf; dQ, dK, the forward output, and the softmax LSE all pass with every element within tolerance.
  • Forward path equivalence & performance tests already pass.

Error traceback
No Python exception is raised; the issue is silent numerical corruption. Sample benchmark log excerpt (trimmed):

Analyzing dV gradients:
Original result range: [-2.718750, 4.156250]
CUDA result range: [nan, nan]
Original result contains NaN: False, Inf: False
CUDA result contains NaN: True, Inf: True
Elements within tolerance ratio: 0.8490 (13910/16384)
Accuracy threshold: Fail

Debugging Information
Config where failure reproduced:

  • batch_size=1
  • num_heads=1, num_kv_heads=1
  • query_len=256, key_len=256
  • head_dim=64
  • is_causal=True
  • dtype: bfloat16
  • Path: dense (mask / bias turned off)

Observations & Hypothesis:

  1. The dV path is a GEMM (P^T @ dO); only this accumulation shows NaNs ⇒ likely uninitialized or OOB (out-of-bounds) lane participation in the MMA or store (a reference sketch of this GEMM follows the list).
  2. Code comments near the dV writeback explicitly disable OOB clearing (“Clear_OOB_K must be false…”), which increases risk if predicate masking is incomplete for tail tiles or vectorized BF16 stores.
  3. Even with dimensions seemingly aligned (kBlockN divides seqlen_k, head_dim divides the tile headdim), internal MMA atom / ValLayout subdivisions (e.g., warpcontiguousN with factor 2) can create per-lane partial OOB accesses that require predication.
  4. Potential sources:
    • Missing zfill / predicate on cp.async loads for last tile of P or dO ⇒ NaN travels into accumulator.
    • Vectorized BF16 store of acc_dv packs stray register values (uninitialized lanes) when Clear_OOB is disabled.
    • Shared memory region reuse or offset miscalculation for sdO / sPt causing overlap before consumers finish.
  5. Less likely (since dQ/dK are fine):
    • Softmax LSE corruption (would affect all grads).
    • Scale or softcap overflow (others would fail too).
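
To make observation 1 concrete, here is a minimal unbatched PyTorch sketch of the dense-path backward math for dV (the reference formula only, not the kernel's tiled MMA implementation; the single-head, single-batch simplification is illustrative):

import torch

def reference_dv(q, k, v, dout, is_causal=True, scale=None):
    # q: (L_q, D), k: (L_k, D), v: (L_k, D), dout: (L_q, D), float32
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale                          # (L_q, L_k)
    if is_causal:
        causal = torch.ones_like(scores, dtype=torch.bool).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    p = scores.softmax(dim=-1)                          # attention probabilities P
    return p.T @ dout                                   # dV = P^T @ dO, shape (L_k, D)

Since the forward output and LSE pass the equivalence check, P and dO should be finite, so non-finite values in dV point at the accumulation or store path rather than the math itself.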

Suggested Immediate Diagnostics:

  • Add debug guard before dV store:
    • Scan acc_dv (float accumulator) for non-finite; if present, log lane indices.
  • Temporarily enable OOB clearing (or predicate store) for dV tail tiles; see if NaNs disappear.
  • Force scalar (non-vectorized) store for dV to test alignment hypothesis.
  • Run with fp32 accumulate + fp32 output (avoid bf16 cast) to see if NaNs only arise after conversion.
  • Test non-divisible shapes (e.g., key_len=192, head_dim=80) to amplify OOB surface—if NaN frequency changes, confirms predicate issue.
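
The last bullet can be scripted directly; a hedged sketch, reusing the same call signature as the repro snippet above (the shape list is illustrative, and it assumes the kernel accepts non-divisible shapes at all):

import torch
import flash_dmattn

def nonfinite_dv_fraction(q_len, k_len, head_dim, is_causal=True):
    torch.manual_seed(0)
    def mk(L):
        return torch.randn(1, L, 1, head_dim, device="cuda",
                           dtype=torch.bfloat16, requires_grad=True)
    q, k, v = mk(q_len), mk(k_len), mk(k_len)
    out = flash_dmattn.triton_dmattn_func(q, k, v, None, None,
                                          is_causal=is_causal, scale=None)
    out.sum().backward()
    return (~torch.isfinite(v.grad)).float().mean().item()

# Aligned vs. deliberately non-divisible shapes; a jump in the non-finite
# fraction for the tail-tile shapes would support the predication hypothesis.
for q_len, k_len, d in [(256, 256, 64), (256, 192, 80), (200, 200, 72)]:
    print(q_len, k_len, d, "non-finite dV fraction:", nonfinite_dv_fraction(q_len, k_len, d))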

Potential Fix Directions:

  • Re-introduce Clear_OOB for dV or implement a per-lane predicate in the gmem store (mask: column < headdim && row < seqlen_k); a conceptual illustration follows the list.
  • Use cp.async.zfill for tail loads of P and dO tiles.
  • Validate shared-memory offset arithmetic for sdO vs sQ vs sP region boundaries.
  • Explicit zero-initialization of the accumulator (clear(acc_dv)) appears to be present already; confirm it happens for every n_block iteration.
  • Ensure convert_type<Element> does not depend on uninitialized fragment lanes (apply predicate before conversion).
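
As a purely conceptual, host-side illustration of the predicated-store direction from the first bullet (the real change would live in the CUDA gmem-store epilogue; predicated_store and its arguments are hypothetical names, not kernel code):

import torch

def predicated_store(acc_tile, dv_global, row0, col0, seqlen_k, headdim):
    # Per-element predicate (row < seqlen_k && col < headdim) emulated with
    # slicing: only the in-bounds part of the accumulator tile is written back,
    # so stray values in out-of-bounds lanes never reach global memory.
    n_rows = max(0, min(acc_tile.shape[0], seqlen_k - row0))
    n_cols = max(0, min(acc_tile.shape[1], headdim - col0))
    dv_global[row0:row0 + n_rows, col0:col0 + n_cols] = acc_tile[:n_rows, :n_cols]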

Workaround (temporary):

  • Enable OOB clearing or explicitly zero any non-finite elements in rdV before the global store (debug-only guard).
  • Run backward in fp32 (if memory permits) to mitigate silent NaN propagation from partial BF16 packs (not a real fix, only scoping aid).
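
A debug-only sketch of the first workaround at the Python level, using a gradient hook on the value tensor from the repro snippet above (this only scopes the problem from the host side; the bullet refers to guarding rdV inside the kernel, which this does not do):

import torch

def nan_guard(grad):
    # Replace non-finite elements in dV so downstream steps stay sane while
    # the kernel bug is investigated. Debug-only; not a fix.
    bad = ~torch.isfinite(grad)
    if bad.any():
        print(f"dV guard: zeroed {bad.sum().item()} non-finite elements")
        grad = grad.masked_fill(bad, 0.0)
    return grad

# Register on v before calling loss.backward() in the repro snippet.
v.register_hook(nan_guard)

For the second bullet, the repro snippet can simply be rerun with dtype=torch.float32 to check whether the non-finite values still appear before any BF16 conversion.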

Metadata

Labels: bug (Something isn't working)
