Is your performance issue related to a problem? Please describe.
The current backward-pass launch templates in Flash-DMA may not be optimally configured for every GPU architecture and problem size. Users see suboptimal performance when running backward passes across different sequence lengths, batch sizes, and head dimensions, particularly on newer GPU architectures (SM 8.9, SM 9.0), where advanced features such as asynchronous execution and improved shared memory hierarchies could be better utilized.
Describe the optimization you'd like
Implement adaptive backward launch template optimization that:
- Automatically selects optimal block sizes based on problem dimensions and GPU architecture (see the selection sketch after this list)
- Utilizes architecture-specific features (e.g., asynchronous copies, distributed shared memory across thread block clusters on H100)
- Provides better load balancing across different sequence length ranges
- Optimizes register usage and occupancy for backward-specific computation patterns
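As a rough illustration of what this selection could look like on the host side, the sketch below derives tile sizes from the problem shape and compute capability. The struct name, function name, thresholds, and tile choices are hypothetical placeholders, not Flash-DMA's actual heuristics.

```cuda
// Hypothetical sketch of host-side configuration selection. Names, thresholds,
// and tile sizes are illustrative placeholders, not Flash-DMA's real logic.
struct BwdLaunchConfig {
    int block_m;        // query-tile rows per CTA
    int block_n;        // key/value-tile rows per CTA
    int num_warps;      // warps per CTA
    bool use_cp_async;  // asynchronous global->shared copies (SM 8.0+)
};

inline BwdLaunchConfig select_bwd_config(int head_dim, int seqlen,
                                         int batch_x_heads, int sm_arch) {
    BwdLaunchConfig cfg;
    // Larger head dimensions need more registers and shared memory per row,
    // so shrink the tiles to avoid spilling.
    cfg.block_m = (head_dim <= 64) ? 128 : 64;
    cfg.block_n = (head_dim <= 64) ? 128 : 64;
    // Small batch x heads with long sequences launches few CTAs; smaller tiles
    // let more CTAs run per SM and keep occupancy up.
    if (batch_x_heads < 16 && seqlen >= 8192) {
        cfg.block_m = 64;
    }
    cfg.num_warps = (cfg.block_m >= 128) ? 8 : 4;
    // cp.async is available from SM 8.0 onward; an SM 9.0 variant could use
    // TMA instead, which would be a separate kernel specialization.
    cfg.use_cp_async = sm_arch >= 80;
    return cfg;
}
```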
Describe alternatives you've considered
- Static template specialization: Pre-define optimal configurations for common scenarios
- Runtime auto-tuning: Benchmark different configurations on first run and cache the results (sketched after this list)
- Heuristic-based selection: Use mathematical models to predict optimal configurations
- Hybrid approach: Combine static specialization with runtime adaptation for edge cases
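A minimal sketch of the runtime auto-tuning alternative, assuming a candidate configuration type and an opaque run_bwd_kernel launcher (both hypothetical); it times each candidate with CUDA events and keeps the fastest.

```cuda
#include <cuda_runtime.h>
#include <cfloat>
#include <functional>
#include <vector>

// Hypothetical auto-tuning sketch: time each candidate configuration once with
// CUDA events and keep the fastest. `run_bwd_kernel` stands in for the real
// backward launcher; Config would be a tile/warp configuration type.
template <typename Config>
Config autotune_bwd(const std::vector<Config>& candidates,
                    const std::function<void(const Config&)>& run_bwd_kernel) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    Config best = candidates.front();
    float best_ms = FLT_MAX;
    for (const Config& cfg : candidates) {
        run_bwd_kernel(cfg);            // warm-up launch
        cudaEventRecord(start);
        run_bwd_kernel(cfg);            // timed launch
        cudaEventRecord(stop);
        cudaEventSynchronize(stop);
        float ms = 0.0f;
        cudaEventElapsedTime(&ms, start, stop);
        if (ms < best_ms) { best_ms = ms; best = cfg; }
    }
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return best;
}
```

In practice the tuned result would be cached per problem shape (as in the cache sketch further down), so benchmarking only happens the first time a shape is seen.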
Implementation details
This optimization would require:
CUDA kernel changes:
- Enhance the backward launch template system in flash_bwd_launch_template.h
- Add architecture-specific optimizations for SM 8.9 and SM 9.0
- Implement adaptive block size selection based on problem dimensions
- Optimize shared memory usage patterns for backward-specific data access (see the shared-memory launch sketch after this list)
- Improve register allocation for gradient accumulation patterns
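To make the shared-memory point concrete, the sketch below shows the kind of launch-side change this involves: opting a kernel into a dynamic shared memory carve-out above the default 48 KB before launching it. The kernel symbol, its parameters, and the grid layout are placeholders, not the actual contents of flash_bwd_launch_template.h.

```cuda
#include <cuda_runtime.h>

// Hypothetical stand-in for the real backward kernel, which consumes a large
// dynamic shared memory tile; the actual Flash-DMA kernels live elsewhere.
__global__ void flash_bwd_kernel_placeholder(const float* dout, float* dq, int seqlen) {
    extern __shared__ char smem[];
    (void)smem; (void)dout; (void)dq; (void)seqlen;  // gradient math omitted in this sketch
}

// Launch-side sketch: above the 48 KB default, a kernel must explicitly opt in
// to a larger dynamic shared memory carve-out on SM 8.0+ before it is launched.
void launch_bwd_placeholder(const float* dout, float* dq, int seqlen,
                            size_t smem_size, cudaStream_t stream) {
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(flash_bwd_kernel_placeholder,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             static_cast<int>(smem_size));
    }
    dim3 grid((seqlen + 127) / 128);  // one CTA per 128-row tile (illustrative)
    dim3 block(128);
    flash_bwd_kernel_placeholder<<<grid, block, smem_size, stream>>>(dout, dq, seqlen);
}
```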
Python API changes:
- Add performance profiling hooks for auto-tuning
- Expose configuration options for advanced users
- Implement a caching mechanism for optimal configurations (see the cache sketch after this list)
- Add backward-specific performance monitoring
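Although these hooks would surface in Python, the configuration cache itself could live in the C++ extension layer. A minimal sketch under that assumption follows; the key layout, field widths, and names are hypothetical, and a production version would also need per-device keys and thread safety.

```cuda
#include <cstdint>
#include <functional>
#include <unordered_map>

// Hypothetical cache of chosen backward configurations, keyed by problem shape
// and device architecture packed into one 64-bit value.
inline uint64_t bwd_config_key(int head_dim, int seqlen_bucket, int batch_heads, int sm_arch) {
    // Assumes head_dim < 2^16, seqlen_bucket < 2^24, batch_heads < 2^16, sm_arch < 2^8.
    return (uint64_t(head_dim) << 48) | (uint64_t(seqlen_bucket) << 24) |
           (uint64_t(batch_heads) << 8) | uint64_t(sm_arch);
}

// Looks up (or computes and stores) the configuration for a given key, so that
// heuristic selection or auto-tuning runs at most once per distinct shape.
template <typename Config>
const Config& cached_bwd_config(uint64_t key, const std::function<Config()>& select) {
    static std::unordered_map<uint64_t, Config> cache;  // per-process; not thread-safe here
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, select()).first;
    }
    return it->second;
}
```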
Performance implications:
- Expected 15-25% improvement in backward pass latency
- Better memory bandwidth utilization (target: >85% peak bandwidth)
- Reduced register spilling on complex head dimension configurations
- Improved occupancy for small batch sizes and long sequences
Compatibility concerns:
- Maintain backward compatibility with existing kernel launches
- Ensure a graceful fallback for unsupported architectures (see the dispatch sketch after this list)
- Validate across SM 8.0, 8.6, 8.9, and 9.0 architectures
- Test with various PyTorch versions (2.0+)
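A small sketch of what the graceful-fallback check could look like at dispatch time, using the standard compute-capability query; the tier names and the idea of dedicated SM 8.0+ and SM 9.0 paths are assumptions about the eventual implementation, not a committed design.

```cuda
#include <cuda_runtime.h>

// Hypothetical dispatch tiers; which architectures get dedicated code paths is
// an assumption about the eventual implementation.
enum class BwdPath { Generic, Sm80Async, Sm90 };

inline BwdPath pick_bwd_path(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    int cc = prop.major * 10 + prop.minor;
    if (cc >= 90) return BwdPath::Sm90;       // Hopper-specific variant
    if (cc >= 80) return BwdPath::Sm80Async;  // Ampere/Ada: cp.async path
    return BwdPath::Generic;                  // graceful fallback: existing launch path
}
```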
Use case
Specific scenarios where this optimization would provide significant benefits:
Sequence lengths and precision:
- Long sequences (8K-32K tokens) where memory bandwidth becomes critical
- Variable length batches where load balancing is challenging
- Mixed precision training with bfloat16/float16 gradients
Target applications:
- Large language model training (LLaMA, GPT architectures)
- Long document processing and summarization
- Code generation models with extended context
- Multi-modal models with long sequence inputs
Current workflow bottlenecks:
- Backward pass taking 60-70% of total training time on long sequences
- Suboptimal GPU utilization during gradient computation phases
- Memory bandwidth underutilization on newer architectures
- Inefficient kernel launches for small batch, long sequence scenarios
Performance benchmarking details
Attach current measurements demonstrating the need for optimization (e.g., baseline backward-pass latency and achieved memory bandwidth).
Additional context
This optimization is particularly important as:
- Training workloads are increasingly dominated by backward pass computation
- New GPU architectures provide underutilized performance features
- Dynamic mask attention backward passes have unique memory access patterns
- Current launch templates were optimized primarily for forward pass patterns
Profiling data
Attach Nsight Compute profiles showing:
- Memory access patterns during backward gradient accumulation
- Register pressure analysis for different head dimensions
- Occupancy limitations with current block size selections
- Shared memory bank conflicts during bias gradient computation
Related work
This optimization is inspired by:
- FlashAttention-2 adaptive kernel selection strategies
- CUTLASS 3.x performance tuning methodologies
- Triton auto-tuning frameworks for GPU kernels
- Recent advances in CUDA 12.x asynchronous execution patterns
Implementation priority
High priority optimization that would benefit:
- All users training with long sequences (>2K tokens)
- Users upgrading to newer GPU architectures
- Applications where backward pass is the primary bottleneck
- Research requiring efficient gradient computation for attention mechanisms
Success metrics
- 15-25% reduction in backward pass latency
- >85% memory bandwidth utilization achieved
- Zero register spilling for common configurations
- >80% occupancy maintained across problem sizes (a runtime occupancy check is sketched below)
- Graceful performance scaling from SM 8.0 to SM 9.0
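For the occupancy metric, one way to sanity-check a chosen block size and shared-memory budget at runtime is the CUDA occupancy API; the kernel symbol below is a placeholder for the real backward kernel.

```cuda
#include <cuda_runtime.h>

// Placeholder standing in for the real backward kernel symbol.
__global__ void bwd_kernel_placeholder(float* grads) { (void)grads; }

// Returns the theoretical occupancy (active warps / max warps per SM) that a
// given block size and dynamic shared memory budget would achieve, which is
// one way to check the ">80% occupancy" target at runtime.
inline float theoretical_occupancy(int block_size, size_t dynamic_smem_bytes) {
    int device = 0, max_blocks = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks, bwd_kernel_placeholder, block_size, dynamic_smem_bytes);
    float active_warps = float(max_blocks) * block_size / 32.0f;
    float max_warps = float(prop.maxThreadsPerMultiProcessor) / 32.0f;
    return active_warps / max_warps;
}
```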