Use dest_offsets directly in LoadPlanner (#7155)
jonb377 committed Jun 11, 2024
1 parent 71c25e6 · commit 7be1d3d
Showing 2 changed files with 28 additions and 5 deletions.
28 changes: 28 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -172,6 +172,34 @@ def test_resharding_different_device_mesh(self):
         save_planner=SPMDSavePlanner(),
         load_planner=SPMDLoadPlanner())
 
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed to change mesh")
+  def test_resharding_transpose_device_mesh(self):
+    dim = self.n_devices // 2
+    model1 = self._get_sharded_model(mesh_shape=(dim, self.n_devices // dim))
+    model2 = self._get_sharded_model(mesh_shape=(self.n_devices // dim, dim))
+    self._save_and_restore(
+        model1,
+        model2,
+        save_planner=SPMDSavePlanner(),
+        load_planner=SPMDLoadPlanner())
+
+  @unittest.skipIf(xr.global_runtime_device_count() == 1,
+                   "Multiple devices needed to change mesh")
+  def test_padded_tensor(self):
+    # Use a linear layer with shape not divisible by the number of devices.
+    model1 = torch.nn.Linear(127, 63).to('xla')
+    model2 = torch.nn.Linear(127, 63).to('xla')
+    mesh = xs.Mesh(range(self.n_devices), (self.n_devices,))
+    # Transpose the sharding to induce resharding in the restore path
+    xs.mark_sharding(model1.weight, mesh, (0, None))
+    xs.mark_sharding(model2.weight, mesh, (None, 0))
+    self._save_and_restore(
+        model1,
+        model2,
+        save_planner=SPMDSavePlanner(),
+        load_planner=SPMDLoadPlanner())
+
   @unittest.skipUnless('CHKPT_PATH' in os.environ,
                        'CHKPT_PATH must be set for multihost checkpoint')
   def test_multihost_checkpoint(self):
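Note: the `_save_and_restore` helper these tests call is defined elsewhere in the test file and is not part of this diff. A minimal sketch of the round trip it presumably performs, assuming the standard torch.distributed.checkpoint entry points and a filesystem checkpoint directory (the helper name `save_and_restore`, the state_dict layout, and the temporary-directory handling here are illustrative, not the test's actual code):

import tempfile

import torch.distributed.checkpoint as dist_cp
from torch_xla.experimental.distributed_checkpoint import (SPMDLoadPlanner,
                                                            SPMDSavePlanner)


def save_and_restore(model1, model2, chkpt_dir=None):
  # Save model1's sharded state_dict with the SPMD save planner, then load
  # the checkpoint in place into model2's state_dict with the load planner.
  chkpt_dir = chkpt_dir or tempfile.mkdtemp()
  dist_cp.save(
      state_dict={'model': model1.state_dict()},
      storage_writer=dist_cp.FileSystemWriter(chkpt_dir),
      planner=SPMDSavePlanner())
  state_dict = {'model': model2.state_dict()}
  dist_cp.load(
      state_dict=state_dict,
      storage_reader=dist_cp.FileSystemReader(chkpt_dir),
      planner=SPMDLoadPlanner())
  model2.load_state_dict(state_dict['model'])
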
5 changes: 0 additions & 5 deletions torch_xla/experimental/distributed_checkpoint/planners.py
@@ -282,11 +282,6 @@ def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
     lengths and offsets into the global tensor.
     """
     offsets = read_item.dest_offsets
-    index = read_item.dest_index
-    if index.fqn in self.sharded_state_dict:
-      # Update offsets to index into the shard rather than the global tensor
-      shard = self._local_shards[index.fqn][index.index]
-      offsets = torch.Size(d - i.start for d, i in zip(offsets, shard.indices))
     return narrow_tensor_by_index(tensor, offsets, read_item.lengths)
 
   def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
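The deleted block translated `dest_offsets` from global-tensor coordinates into shard-local coordinates before narrowing; with this change the planner uses each `ReadItem`'s `dest_offsets` directly, on the assumption that they already index into the local shard. A rough sketch of the narrowing step, assuming `narrow_tensor_by_index` reduces to a per-dimension `torch.narrow` (`narrow_by_offsets` is a hypothetical stand-in, not the library function):

import torch


def narrow_by_offsets(tensor, offsets, lengths):
  # Narrow each dimension to the window [offset, offset + length), leaving
  # dimensions that already have the requested length untouched.
  out = tensor
  for dim, (offset, length) in enumerate(zip(offsets, lengths)):
    if length < out.size(dim):
      out = out.narrow(dim, offset, length)
  return out


# Example: a 2x4 local shard; shard-relative offsets select columns 1 and 2.
shard = torch.arange(8).reshape(2, 4)
print(narrow_by_offsets(shard, offsets=(0, 1), lengths=(2, 2)))
# tensor([[1, 2],
#         [5, 6]])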
