Description
Describe the bug
With flash-dmattn v1.1.9 and a training configuration of seq_len=4096 and window=2048, the backward pass of the very first training step consistently fails with an INF (infinity) error. Preliminary investigation points to this location in the code: https://github.com/SmallDoges/flash-dmattn/blob/main/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py#L91. The root cause appears to be that some intermediate values fall outside the representable range of the bf16 data type.
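For illustration only (this is a standalone sketch, not the failing code path itself): bf16 arithmetic overflows to -inf as soon as a value exceeds roughly 3.39e38 in magnitude, and once a softmax row contains only -inf, the backward pass produces NaN gradients, consistent with the grad-norm NaN in the traceback below. For example, adding two mask biases of torch.finfo(torch.bfloat16).min already exceeds the representable range:

```python
import torch

neg_min = torch.finfo(torch.bfloat16).min          # about -3.39e38
bias = torch.tensor([neg_min], dtype=torch.bfloat16, requires_grad=True)

# Adding two "minimum" mask biases already exceeds the bf16 range -> -inf
scores = bias + neg_min
print(scores)                                      # tensor([-inf], ...)

# A softmax row that is entirely -inf evaluates to NaN, and backward then
# propagates NaN into the gradients.
probs = torch.softmax(scores, dim=-1)
print(probs)                                       # tensor([nan], ...)
probs.sum().backward()
print(bias.grad)                                   # tensor([nan], ...)
```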
To Reproduce
Steps to reproduce the behavior:
- Import flash_dmattn
- Run the following code:
# Paste your code here
- See error
Expected behavior
The backward pass of the first training step completes without producing INF or NaN values.
Environment Information
PyTorch: 2.6.0+cu124
CUDA: 12.4
GPU: NVIDIA A800-SXM4-80GB
Additional context
- OS: Ubuntu 20.04
- Python version: 3.12.9
- Flash-DMA version: 1.1.9
- Compute Capability: 8.0
Error traceback
RuntimeError: Rank 0, node job-7db94950-80fb-47db-bacf-a9a63edec186-master-0, device 0, iteration 1: Unexpected result nan (message='found NaN in local grad norm for bucket #0 in backward pass
Debugging Information
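One way to narrow down where the non-finite values first appear is PyTorch's built-in anomaly detection combined with a full-backward hook. The sketch below is a hedged suggestion using a placeholder module and shapes, not the actual attention stack or training setup:

```python
import torch
import torch.nn as nn

# Anomaly detection makes autograd raise at the first backward op that yields
# nan/inf and prints the forward traceback that created the offending tensor.
torch.autograd.set_detect_anomaly(True)

def report_bad_grads(module, grad_input, grad_output):
    # Flag any module whose outgoing gradient contains nan or inf.
    for g in grad_output:
        if g is not None and not torch.isfinite(g).all():
            print(f"non-finite gradient flowing into {module.__class__.__name__}")

# Placeholder module standing in for the attention layer under test.
model = nn.Linear(8, 8, dtype=torch.bfloat16)
model.register_full_backward_hook(report_bad_grads)

x = torch.randn(2, 8, dtype=torch.bfloat16, requires_grad=True)
model(x).float().sum().backward()
```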
