I have a working context-parallel implementation, forked from this repo, that supports forward and backward passes. It required two modifications:
- padding each GPU's conv-layer input chunk with the last N_padding tokens of the previous GPU's chunk, then discarding the outputs at the padding indices
- transferring the final states in state passing point-to-point between GPUs, sequentially
The backward pass does the same in reverse. I believe I've also worked out a way to do this without the sequential point-to-point transfers.
Would this be useful to contribute? If so, I'd like to know the best way to go about it, since it requires modifying the core wrapper of the Mamba 2 Triton code.
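To make the first modification concrete, here is a toy, dependency-free sketch of the halo-padding idea (the real code operates on Triton/CUDA kernels and `torch.distributed` buffers; function names like `context_parallel_conv` and the choice N_padding = K-1 for a width-K causal conv are my assumptions, not the actual implementation):

```python
def valid_conv1d(x, w):
    # "Valid" 1D convolution: one output per fully-overlapping window.
    K = len(w)
    return [sum(w[k] * x[t + k] for k in range(K))
            for t in range(len(x) - K + 1)]

def context_parallel_conv(x, w, n_ranks):
    # Toy simulation of context parallelism for a causal conv layer.
    # Each "rank" pads its chunk with the last K-1 tokens of the
    # previous rank's chunk (zeros on rank 0, matching causal padding);
    # a valid conv on the padded chunk then yields exactly that rank's
    # slice of the full causal-conv output, so the padding-token
    # outputs are discarded implicitly.
    # Assumes len(x) is divisible by n_ranks, for simplicity.
    K = len(w)
    step = len(x) // n_ranks
    chunks = [x[r * step:(r + 1) * step] for r in range(n_ranks)]
    out = []
    for r, chunk in enumerate(chunks):
        halo = chunks[r - 1][-(K - 1):] if r > 0 else [0.0] * (K - 1)
        out.extend(valid_conv1d(halo + chunk, w))
    return out
```

Because the windows crossing a chunk boundary see the real tokens from the previous chunk (rather than zeros), the chunked result matches the single-device causal conv exactly.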
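And a similar toy sketch for the second modification, the sequential point-to-point state passing, using a scalar linear recurrence h[t] = a·h[t-1] + x[t] as a stand-in for the SSM state update (the loop over ranks stands in for the send/recv chain; names here are illustrative, not from the repo):

```python
def local_scan(x_chunk, a, h_in):
    # Run the recurrence h[t] = a * h[t-1] + x[t] over one chunk,
    # starting from the state handed over by the previous rank.
    # Returns the per-token outputs and the final state to pass on.
    h, ys = h_in, []
    for xt in x_chunk:
        h = a * h + xt
        ys.append(h)
    return ys, h

def sequential_state_passing(x, a, n_ranks):
    # Toy simulation of sequential point-to-point state passing:
    # rank r waits for rank r-1's final state, scans its own chunk,
    # then forwards its final state to rank r+1.
    # Assumes len(x) is divisible by n_ranks, for simplicity.
    step = len(x) // n_ranks
    h, out = 0.0, []
    for r in range(n_ranks):
        y, h = local_scan(x[r * step:(r + 1) * step], a, h)
        out.extend(y)
    return out
```

Since each rank performs exactly the operations the single-device scan would, the chunked output matches the full scan; the cost is that the ranks' scans are serialized along the chain, which is what the non-sequential variant mentioned above would avoid.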