Multi-GPU Context Parallel Mamba2 #664
                
     Open
            
            
          
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
I've made an implementation here of Context Parallelism for Mamba 2. It uses a sequential step at the state transfer stage, but otherwise functions in parallel. I've validate that the results are numerically within floating point error between single GPU context and multi-GPU Context, for both forward and backward pass calculations.
It uses a hack of the causal_conv1d function by transfering the number of tokens equivalent to the convolution window between GPUs and then discarding the result for few prepended tokens on each GPU. This requires a new ContextMixer layer to be inserted before each Mamba2 Layer, which is automatically inserted in a modification to the Mamba 2 class. The actual GPU to GPU transfer is done in a loop in the ssd_combined function.
Please let me know how I can further improve the PR to make it a mergeable contribution. Also feel free to reach out if you'd like help setting up a multi-GPU context parallel run.
N.B. this PR does not include splitting of the initial input sequence or aggregating gradients after loss, both of which would need to be performed by the training loop code.