Add optimizer priming for dist chkpt (#6572)
jonb377 authored May 30, 2024
1 parent 8c2234e commit 8fd051f
Showing 6 changed files with 171 additions and 5 deletions.
20 changes: 18 additions & 2 deletions docs/spmd.md
@@ -297,7 +297,7 @@ checkpointing directly to any fsspec-compatible filesystem, including GCS.
Example usage of the CheckpointManager is below:

```python
from torch_xla.experimental.distributed_checkpoint import CheckpointManager
from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer

# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)
@@ -307,9 +307,13 @@ tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
  # Choose the highest step
  best_step = max(tracked_steps)
  state_dict = {'model': model.state_dict()}
  # Before restoring the checkpoint, the optimizer state must be primed
  # to allow state to be loaded into it.
  prime_optimizer(optim)
  state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
  chkpt_mgr.restore(best_step, state_dict)
  model.load_state_dict(state_dict['model'])
  optim.load_state_dict(state_dict['optim'])

# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
@@ -320,6 +324,18 @@ for step, data in enumerate(dataloader):
    print(f'Checkpoint taken at step {step}')
```

##### Restoring Optimizer State

In distributed checkpointing, state_dicts are loaded in-place, and only the
required shards of the checkpoint are read. Since optimizer state is created
lazily, it does not exist until the first `optimizer.step` call, and attempting
to load into an unprimed optimizer will fail.

The utility method `prime_optimizer` is provided for this purpose: it runs a
fake train step by setting all gradients to zero and calling `optimizer.step`.
*This is a destructive method and will touch both model parameters and
optimizer state*, so it should only be called just prior to restoration.
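
Conceptually, priming amounts to a zero-gradient optimizer step. The sketch
below is illustrative only (the helper name `prime_optimizer_sketch` is made up
for this example) and assumes an already-constructed `optim`; the real
`prime_optimizer` additionally propagates each parameter's sharding onto its
gradient and synchronizes the device before and after the step.

```python
import torch
import torch_xla.core.xla_model as xm

def prime_optimizer_sketch(optim: torch.optim.Optimizer) -> None:
  # Give every trainable parameter a zero gradient so that `optim.step()`
  # allocates its per-parameter state (momentum buffers, Adam moments, ...).
  for group in optim.param_groups:
    for p in group['params']:
      if isinstance(p, torch.Tensor) and p.requires_grad:
        p.grad = torch.zeros_like(p, requires_grad=False)
  optim.step()
  # Flush the pending XLA computation so the new state tensors are materialized.
  xm.mark_step()
```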

### Process Groups
To use `torch.distributed` APIs such as distributed checkpointing, a process
group is required. In SPMD mode, the `xla` backend is not supported since the
95 changes: 93 additions & 2 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -17,11 +17,13 @@
import torch_xla.distributed.spmd as xs

from torch.distributed.checkpoint._fsspec_filesystem import *
from collections.abc import Iterable

from torch.distributed.checkpoint.default_planner import (
    create_default_local_save_plan,
    create_default_global_save_plan,
)
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager, prime_optimizer
from torch_xla.experimental.distributed_checkpoint._helpers import (
    _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)

@@ -68,6 +70,33 @@ def _same_shard_data(self, shards, others) -> bool:
        return False
    return True

  def _assert_same_state_dict(self, sd1, sd2, keypath=""):
    assert type(sd1) == type(
        sd2), f"Different types in state_dict: {sd1} vs {sd2}"

    if isinstance(sd1, torch.Tensor):
      assert sd1.device == sd2.device, f"Tensors on different devices at {keypath}: {sd1} vs {sd2}"
      if sd1.device == xm.xla_device():
        sharding1 = torch_xla._XLAC._get_xla_sharding_spec(sd1)
        sharding2 = torch_xla._XLAC._get_xla_sharding_spec(sd2)
        assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}"
      assert torch.equal(
          sd1, sd2), f"Different tensors at {keypath}:\n{sd1} vs {sd2}"

    elif isinstance(sd1, dict):
      assert sd1.keys() == sd2.keys(
      ), f"Different keys at {keypath}: {sd1} vs {sd2}"
      for key in sd1:
        self._assert_same_state_dict(
            sd1[key], sd2[key], keypath=f'{keypath}.{key}')

    elif isinstance(sd1, Iterable):
      for ind, (a, b) in enumerate(zip(sd1, sd2)):
        self._assert_same_state_dict(a, b, keypath=f'{keypath}[{ind}]')

    else:
      assert sd1 == sd2, f"Different value at {keypath}: {sd1} vs {sd2}"


class EndToEndCheckpointTest(DistributedCheckpointTestBase):

@@ -357,7 +386,7 @@ class CheckpointManagerTest(DistributedCheckpointTestBase):

  def setUp(self):
    super().setUp()
    # Initialize the a minimal process group
    # Initialize a minimal process group
    dist.init_process_group(
        backend='gloo',
        init_method='tcp://localhost:8932',
@@ -565,6 +594,68 @@ def test_auto_checkpoint(self, tmpdir):
    self.assertTrue(chkpt_mgr.reached_preemption(step))


@unittest.skipIf(xr.device_type() != 'TPU',
                 'TPU required for worker IP discovery')
class OptimizerCheckpointTest(DistributedCheckpointTestBase):

  def setUp(self):
    super().setUp()
    # Initialize a minimal process group
    dist.init_process_group(
        backend='gloo',
        init_method='tcp://localhost:8932',
        world_size=1,
        rank=0)

  def tearDown(self):
    super().tearDown()
    # Destroy the CPU process group after the test
    dist.destroy_process_group()

  def _get_model_and_optimizer(self, optim_cls):
    model = self._get_sharded_model()
    optim = optim_cls(params=model.parameters())
    return model, optim

  def _run_train_step(self, model, optim):
    torch.manual_seed(42)
    model(torch.ones(10, 128).to('xla')).square().sum().backward()
    optim.step()
    xm.mark_step()

  def _test_optimizer(self, tmpdir, optim_cls):
    model, optim = self._get_model_and_optimizer(optim_cls)
    self._run_train_step(model, optim)

    # Take a checkpoint including the optimizer
    chkpt_mgr = CheckpointManager(tmpdir, 1)
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    chkpt_mgr.save(0, state_dict, force=True)

    # Load the checkpoint into a new model and optimizer
    new_model, new_optim = self._get_model_and_optimizer(optim_cls)
    prime_optimizer(new_optim)
    new_state_dict = {
        'model': new_model.state_dict(),
        'optim': new_optim.state_dict()
    }
    chkpt_mgr.restore(0, new_state_dict)
    self._assert_same_state_dict(state_dict, new_state_dict)

    new_model.load_state_dict(new_state_dict['model'])
    new_optim.load_state_dict(new_state_dict['optim'])
    self._assert_same_state_dict(new_model.state_dict(), model.state_dict())
    self._assert_same_state_dict(new_optim.state_dict(), optim.state_dict())

  @run_with_tmpdir
  def test_sgd(self, tmpdir):
    self._test_optimizer(tmpdir, torch.optim.SGD)

  @run_with_tmpdir
  def test_adamw(self, tmpdir):
    self._test_optimizer(tmpdir, torch.optim.AdamW)


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
10 changes: 10 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2061,6 +2061,16 @@ void InitXlaModuleBindings(py::module m) {
    XLATensorPtr xtensor = bridge::GetXlaTensor(input);
    return GetXLAShardingSpec(xtensor);
  });
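  // Returns the OpSharding proto attached to `input`, or std::nullopt if the
  // tensor has no sharding spec.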
  m.def("_get_xla_op_sharding",
        [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
          XLATensor::ShardingSpecPtr sharding_spec =
              xtensor ? xtensor->sharding_spec() : nullptr;
          if (sharding_spec != nullptr) {
            return sharding_spec->sharding;
          }
          return std::nullopt;
        });
  m.def("_get_xla_sharding_specs",
        [](const std::vector<at::Tensor>& tensors) -> std::vector<std::string> {
          tsl::profiler::TraceMe activity("_get_xla_sharding_specs",
2 changes: 2 additions & 0 deletions torch_xla/experimental/distributed_checkpoint/__init__.py
@@ -1,8 +1,10 @@
from .manager import CheckpointManager
from .planners import SPMDSavePlanner, SPMDLoadPlanner
from .util import prime_optimizer

__all__ = [
    "CheckpointManager",
    "SPMDSavePlanner",
    "SPMDLoadPlanner",
    "prime_optimizer",
]
5 changes: 4 additions & 1 deletion torch_xla/experimental/distributed_checkpoint/manager.py
@@ -66,7 +66,10 @@ class CheckpointManager:
>>> if tracked_steps:
>>>   # Choose the highest step
>>>   best_step = max(tracked_steps)
>>>   state_dict = {'model': model.state_dict()}
>>>   # Before restoring the checkpoint, the optimizer state must be primed
>>>   # to allow state to be loaded into it.
>>>   prime_optimizer(optim)
>>>   state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
>>>   chkpt_mgr.restore(best_step, state_dict)
>>>   model.load_state_dict(state_dict['model'])
44 changes: 44 additions & 0 deletions torch_xla/experimental/distributed_checkpoint/util.py
@@ -0,0 +1,44 @@
import torch
from torch.utils._pytree import tree_map
import torch_xla
import torch_xla.core.xla_model as xm


def prime_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer:
  """
  Prime the optimizer state by running a dummy weight update.

  Optimizer state isn't created until after the first training step. Since the
  distributed checkpointing library loads the state_dict in-place, the
  optimizer state must already exist before loading the checkpoint.

  This utility method runs a dummy weight update with zero gradient to ensure
  the optimizer state exists and can be loaded into.

  **Warning** This method calls `optimizer.step`, which can impact the
  optimizer's state and model parameters. Therefore, it should only be used
  prior to restoring a checkpoint, when the state and parameters will be
  immediately overwritten.

  Args:
    optimizer: The optimizer whose state should be primed for checkpoint
      loading.
  """

  # Initial mark_step to ensure all param_groups are backed by device data.
  xm.mark_step()
  xm.wait_device_ops()

  def zero_grad(x):
    if isinstance(x, torch.Tensor) and x.requires_grad:
      x.grad = torch.zeros_like(x, requires_grad=False)
      param_sharding = torch_xla._XLAC._get_xla_op_sharding(x)
      if param_sharding:
        # Match the gradient sharding to the parameter's.
        torch_xla._XLAC._xla_mark_sharding(x.grad, param_sharding)

  tree_map(zero_grad, optimizer.param_groups)
  optimizer.step()
  xm.mark_step()
  xm.wait_device_ops()
  return optimizer
