[BUG REPORT] INF occurs in backward pass of the first training step #180

@ftgreat

Description

Describe the bug
Using code version v1.1.9 with the training configuration seq_len=4096 and window=2048, the backward pass of the first training step consistently fails with an INF (infinity) error. Preliminary investigation pinpoints the failure to https://github.com/SmallDoges/flash-dmattn/blob/main/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py#L91. The root cause appears to be that some intermediate values fall outside the finite range representable by the bf16 data type.
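
As background on the bf16 hypothesis, here is a minimal standalone illustration in plain PyTorch (not flash_dmattn code) of how a value beyond bf16's largest finite value overflows to INF:

```python
import torch

# bfloat16 keeps float32's 8-bit exponent, so its largest finite value is
# ~3.39e38; anything that rounds past that becomes inf when cast to bf16.
print(torch.finfo(torch.bfloat16).max)   # 3.3895313892515355e+38

x = torch.tensor([3.4e38, -3.4e38], dtype=torch.float32)
print(x.to(torch.bfloat16))              # tensor([inf, -inf], dtype=torch.bfloat16)
```

Once such an inf appears in a bf16 intermediate, downstream operations in the backward pass can turn it into NaN (e.g. inf - inf), which matches the grad-norm error below.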

To Reproduce
Steps to reproduce the behavior:

  1. Import flash_dmattn
  2. Run the following code:
# Paste your code here
  3. See error

Expected behavior
The backward pass completes without producing INF values.

Environment Information
PyTorch: 2.6.0+cu124
CUDA: 12.4
GPU: NVIDIA A800-SXM4-80GB

Additional context

  • OS: Ubuntu 20.04
  • Python version: 3.12.9
  • Flash-DMA version: 1.1.9
  • Compute Capability: 8.0

Error traceback

RuntimeError: Rank 0, node job-7db94950-80fb-47db-bacf-a9a63edec186-master-0, device 0, iteration 1: Unexpected result nan (message='found NaN in local grad norm for bucket #0 in backward pass 

Debugging Information

[Screenshot of debugging output attached in the original issue]
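
To localize the failure further, a debugging sketch in plain PyTorch; `model` and `batch` are hypothetical stand-ins for the objects in the actual distributed training loop, and it assumes the failing step can be re-run in isolation:

```python
import torch

# Reproduce one failing step; `model` and `batch` are hypothetical stand-ins
# for the objects used in the real training loop.
loss = model(**batch).loss
loss.backward()

# 1) Scan per-parameter gradients to see which layer's grad first goes
#    non-finite, before the bucket-level grad-norm check reports NaN.
for name, p in model.named_parameters():
    if p.grad is not None and not torch.isfinite(p.grad).all():
        print(name, p.grad.dtype, float(p.grad.abs().max()))

# 2) If that is not specific enough, anomaly mode re-runs backward with extra
#    checks and reports the forward-pass trace of the op whose gradient first
#    becomes NaN (an upstream INF typically turns into NaN via inf - inf or
#    0 * inf). It is slow, so enable it only for the failing step.
model.zero_grad()
with torch.autograd.set_detect_anomaly(True):
    model(**batch).loss.backward()
```

Either check should narrow the failure from "NaN in bucket #0" down to a specific parameter or backward op.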
