Skip to content

Commit

Permalink
[Checkpoint]Update dt_planner to use local_offsets from DTensorSpec (p…
Browse files Browse the repository at this point in the history
…ytorch#617)

Update dt_planner to use local_offsets() from DTensorSpec.

Test:
```
python3 test/spmd/checkpoint/test_dt_planner.py
```
  • Loading branch information
wz337 authored Nov 17, 2022
1 parent fe56150 commit 25542e6
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions spmd/checkpoint/dt_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,9 @@ def get_box_for(
device_mesh = tensor.device_mesh
assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

placement = tensor.placements[0]
offsets = [0] * len(tensor.size())
num_chunks = device_mesh.size(dim=0)

if tensor.placements[0].is_shard():
shard_dim = placement.dim # type: ignore # pyre-ignore[16]
chunk_size = tensor.size(shard_dim) // num_chunks
offsets[shard_dim] = chunk_size

size = tensor.to_local().size()
offsets = [val * idx for val in offsets] # type: ignore
offsets = tensor._spec.local_offsets

return (torch.Size(offsets), size)


Expand Down

0 comments on commit 25542e6

Please sign in to comment.