State dict serialization #51
base: main
Conversation
This looks really solid; I'm impressed with how quickly this came together. A couple of nits, but nothing major.
Could you please measure the speed increase in the e2e test_models tests and report it in this PR?
edit: Ah, I just remembered we talked about landing this in stages, so I think some of the necessary plumbing won't exist.
torchstore/state_dict_utils.py
Outdated
size: int  # Size in bytes

def generate_tensor_blob(state_dict: Dict[str, Any]):
Can we use flatten_state_dict instead of making this recursive?
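For reference, a minimal sketch of what this suggestion could look like, assuming the reviewer means the flatten_state_dict helper from torch.distributed.checkpoint (an internal module, so treat this as illustrative):

```python
import torch
from torch.distributed.checkpoint._nested_dict import (
    flatten_state_dict,
    unflatten_state_dict,
)

nested = {"model": {"layer": {"weight": torch.randn(2, 2)}}}

# flatten_state_dict returns a flat {dotted_key: tensor} dict plus a mapping
# that unflatten_state_dict later uses to rebuild the original nesting.
flat, mapping = flatten_state_dict(nested)
for key, value in flat.items():  # plain iteration, no recursion needed
    print(key, tuple(value.shape))

rebuilt = unflatten_state_dict(flat, mapping)
```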
What do you think about making this a class method of a "TorchStoreStateDict", or similar?
Then we can do things like:
torchstore_sd = TorchStoreStateDict.from_state_dict(original_state_dict)
torchstore_sd.to_state_dict()
and also store any necessary data as objects in the state dict.
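For concreteness, a minimal sketch of the interface proposed above; the field names are illustrative assumptions, not the final API:

```python
from dataclasses import dataclass
from typing import Any, Dict

import torch


@dataclass
class TorchStoreStateDict:
    # Flat {dotted_key: reference-or-object} view of the original dict.
    flattened_state_dict: Dict[str, Any]
    # All tensor bytes packed into one contiguous uint8 buffer.
    tensor_blob: torch.Tensor

    @classmethod
    def from_state_dict(cls, state_dict: Dict[str, Any]) -> "TorchStoreStateDict":
        ...  # flatten, replace tensors with references, pack the blob

    def to_state_dict(self) -> Dict[str, Any]:
        ...  # rebuild tensors from the blob and restore the nesting
```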
Done.
torchstore/state_dict_utils.py
Outdated
return modified_state_dict, torch.empty(0, dtype=torch.uint8)

# Calculate total size and update offsets
current_offset = 0
I think we have a _state_dict_size function in state dict utils
_state_dict_size calculates an approximate size; it returns size << 20.
I believe the largest deltas on this PR are:
- Implementing as class
- Managing DTensor
My recommendation for DTensor is to first convert it to a tensor slice and store all additional metadata in the state dict. (This is actually my standing advice for dealing with DTensor, so we can reduce the number of branches in the codebase.)
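A minimal sketch of that recommendation, assuming a hypothetical LocalSlice record (torchstore's actual TensorSlice may differ) and using the private PyTorch helper this PR already imports:

```python
from dataclasses import dataclass
from typing import Tuple

import torch
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset


@dataclass
class LocalSlice:  # hypothetical stand-in for torchstore's TensorSlice
    offsets: Tuple[int, ...]      # where the shard sits in the global tensor
    local_shape: Tuple[int, ...]
    global_shape: Tuple[int, ...]


def to_tensor_and_slice(dtensor: DTensor) -> Tuple[torch.Tensor, LocalSlice]:
    # Convert the DTensor into a plain local tensor plus slice metadata,
    # so downstream code never needs a DTensor branch.
    local_shape, global_offset = _compute_local_shape_and_global_offset(
        dtensor.shape, dtensor.device_mesh, dtensor.placements
    )
    return dtensor.to_local(), LocalSlice(
        offsets=tuple(global_offset),
        local_shape=tuple(local_shape),
        global_shape=tuple(dtensor.shape),
    )
```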
Updated to class representation and using flattened state dict, which makes a lot of sense because list iteration is way simpler than recursion. Also added DTensor support.

Haven't gone through the code but have a general question in mind.

Hi @casteryh, for getting a DTensor with a different sharding plan, the current interface in torchstore is to specify the desired sharding plan via an in-place tensor on get. Right now in this PR, it only supports
Overall LGTM.
Still have some questions:
Currently, this is not integrated with torchstore.put / torchstore.get, right?
For example, if I have a state dict sd = {"a": t} where t is a DTensor sharded across two ranks (see the sketch after this list):
- On each rank, if I do ts_sd = TorchStoreStateDict.from_state_dict(sd), then ts_sd will no longer contain DTensors, right?
- Consequently, if I do a ts.put("state_dict_key", ts_sd) on both ranks, then torchstore is supposed to detect that ts_sd is a TorchStoreStateDict and handle the sharding logic accordingly, right? <- My understanding is this part is not done yet.
tests/test_state_dict.py
Outdated
assert torchstore_state_dict.flattened_state_dict == {}
assert len(torchstore_state_dict.tensor_blob) == 0
this seems to be testing implementation details as opposed to behaviors.
Good point. Removed.
tests/test_state_dict.py
Outdated
assert len(torchstore_state_dict.tensor_blob) == 0
reconstructed = torchstore_state_dict.to_state_dict()
same here
Removed.
tests/test_state_dict.py
Outdated
scalar_dict = {"scalar": torch.tensor(3.14159)}
torchstore_state_dict = TorchStoreStateDict.from_state_dict(scalar_dict)
# Check flattened state dict has TensorReference
scalar_ref = torchstore_state_dict.flattened_state_dict["scalar"]
same
Done.
# Create DTensor from local tensor
local_tensor = torch.randn(4, 6, dtype=torch.float32)
dtensor = DTensor.from_local(local_tensor, device_mesh, [Replicate()])
Can you add a test for a sharded DTensor (with world size > 1)? I am actually also confused about the expected behavior in this case.
That DTensor put-then-get functionality will be added in the next PR, where we integrate the state_dict functionality into torchstore. This PR only does the serialization and deserialization part.
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset

def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
Suggested change:
-def create_tensor_slice_from_dtensor(dtensor: DTensor) -> "TensorSlice":
+from torchstore.transport.pipe import TensorSlice
+def create_tensor_slice_from_dtensor(dtensor: DTensor) -> TensorSlice:
See the response to the other comment about import ordering.
    Returns:
        TensorSlice containing the distributed tensor metadata
    """
    from torchstore.transport.pipe import TensorSlice
Is there a particular reason to avoid importing this at the file level?
If not, move the import to the top of the file.
So there's a circular dependency:

pipe.TensorSlice
    ^
    |
dtensor_util.create_tensor_slice_from_dtensor
    ^
    |
pipe.Request.from_dtensor

Maybe we should put the TensorSlice definition into the dtensor_util.py module?
Would from __future__ import annotations fix this?
If not, then just leave it as is.
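For what it's worth, a minimal sketch of how postponed annotations could break the cycle for the type annotation; the runtime import needed to actually construct a TensorSlice would still have to stay inside the function body:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

from torch.distributed.tensor import DTensor

if TYPE_CHECKING:
    # Seen only by type checkers, so no circular import at runtime.
    from torchstore.transport.pipe import TensorSlice


def create_tensor_slice_from_dtensor(dtensor: DTensor) -> TensorSlice:
    # Deferred import: by the time this runs, pipe is fully initialized.
    from torchstore.transport.pipe import TensorSlice

    ...
```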
Synced with Yuxuan and Lucas. We will do DTensor put and get in the next PR. The current PR only makes sure that DTensor can be serialized and deserialized properly.
Maybe try from __future__ import annotations.
If it doesn't work, then don't bother.
This PR creates two functions for state_dict serialization and deserialization.
- generate_tensor_blob recursively looks for tensors in the state_dict and serializes them into a blob of tensors, replacing each tensor in the state_dict with a TensorReference. TensorReference contains the metadata for the offset, shape, and dtype of the original tensor.
- reconstruct_state_dict_from_tensor_blob does the reverse operation of generate_tensor_blob: it takes the tensor blob and the state_dict (containing only tensor references) and replaces all of the tensor references inside the state_dict with tensors reconstructed from the tensor blob and the TensorReference metadata.
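A rough sketch of the mechanism described above, assuming simple dense CPU tensors and a flat state_dict; the field and function names follow the PR description, but the bodies are illustrative, not the PR's actual implementation:

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

import torch


@dataclass
class TensorReference:
    offset: int              # byte offset into the blob
    size: int                # size in bytes
    shape: Tuple[int, ...]
    dtype: torch.dtype


def generate_tensor_blob(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], torch.Tensor]:
    refs: Dict[str, Any] = {}
    chunks = []
    offset = 0
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            # Flatten first so 0-dim tensors can also be viewed as raw bytes.
            data = value.detach().contiguous().reshape(-1).view(torch.uint8)
            refs[key] = TensorReference(offset, data.numel(), tuple(value.shape), value.dtype)
            chunks.append(data)
            offset += data.numel()
        else:
            refs[key] = value  # non-tensor values pass through unchanged
    blob = torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.uint8)
    return refs, blob


def reconstruct_state_dict_from_tensor_blob(
    refs: Dict[str, Any], blob: torch.Tensor
) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, value in refs.items():
        if isinstance(value, TensorReference):
            raw = blob[value.offset : value.offset + value.size]
            # clone() re-bases storage at offset 0 so the dtype view is aligned.
            out[key] = raw.clone().view(value.dtype).reshape(value.shape)
        else:
            out[key] = value
    return out
```

Round-tripping a state dict through generate_tensor_blob and then reconstruct_state_dict_from_tensor_blob should give back tensors equal to the originals.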