From 8fd051f2d8f57703f70f58a7e0b0208d1c5f74c5 Mon Sep 17 00:00:00 2001 From: jonb377 Date: Thu, 30 May 2024 13:29:31 -0700 Subject: [PATCH] Add optimizer priming for dist chkpt (#6572) --- docs/spmd.md | 20 +++- test/spmd/test_xla_distributed_checkpoint.py | 95 ++++++++++++++++++- torch_xla/csrc/init_python_bindings.cpp | 10 ++ .../distributed_checkpoint/__init__.py | 2 + .../distributed_checkpoint/manager.py | 5 +- .../distributed_checkpoint/util.py | 44 +++++++++ 6 files changed, 171 insertions(+), 5 deletions(-) create mode 100644 torch_xla/experimental/distributed_checkpoint/util.py diff --git a/docs/spmd.md b/docs/spmd.md index 5c60360fdc3..84236c13469 100644 --- a/docs/spmd.md +++ b/docs/spmd.md @@ -297,7 +297,7 @@ checkpointing directly to any fsspec-compatible filesystem, including GCS. Example usage of the CheckpointManager is below: ```python -from torch_xla.experimental.distributed_checkpoint import CheckpointManager +from torch_xla.experimental.distributed_checkpoint import CheckpointManager, prime_optimizer # Create a CheckpointManager to checkpoint every 10 steps into GCS. chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10) @@ -307,9 +307,13 @@ tracked_steps = chkpt_mgr.all_steps() if tracked_steps: # Choose the highest step best_step = max(tracked_steps) - state_dict = {'model': model.state_dict()} + # Before restoring the checkpoint, the optimizer state must be primed + # to allow state to be loaded into it. + prime_optimizer(optim) + state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} chkpt_mgr.restore(best_step, state_dict) model.load_state_dict(state_dict['model']) + optim.load_state_dict(state_dict['optim']) # Call `save` or `save_async` every step within the train loop. These methods # return True when a checkpoint is taken. @@ -320,6 +324,18 @@ for step, data in enumerate(dataloader): print(f'Checkpoint taken at step {step}') ``` +##### Restoring Optimizer State + +In distributed checkpointing, the state_dicts are loaded in-place, and only the +required shards of the checkpoint are loaded. Since optimizer states are lazily +created, the state isn't present until the first `optimizer.step` call, and +attempts to load an unprimed optimizer will fail. + +The utility method `prime_optimizer` is provided for this: it runs a fake train +step by setting all gradients to zero and calling `optimizer.step`. *This is a +destructive method and will touch both model parameters and optimizer state*, +so it should only be called just prior to restoration. + ### Process Groups To use `torch.distributed` APIs such as distributed checkpointing, a process group is required. 
In SPMD mode, the `xla` backend is not supported since the diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index a78057210ab..a035a3f11bd 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -17,11 +17,13 @@ import torch_xla.distributed.spmd as xs from torch.distributed.checkpoint._fsspec_filesystem import * +from collections.abc import Iterable + from torch.distributed.checkpoint.default_planner import ( create_default_local_save_plan, create_default_global_save_plan, ) -from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager +from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager, prime_optimizer from torch_xla.experimental.distributed_checkpoint._helpers import ( _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) @@ -68,6 +70,33 @@ def _same_shard_data(self, shards, others) -> bool: return False return True + def _assert_same_state_dict(self, sd1, sd2, keypath=""): + assert type(sd1) == type( + sd2), f"Different types in state_dict: {sd1} vs {sd2}" + + if isinstance(sd1, torch.Tensor): + assert sd1.device == sd2.device, f"Tensors on different devices at {keypath}: {sd1} vs {sd2}" + if sd1.device == xm.xla_device(): + sharding1 = torch_xla._XLAC._get_xla_sharding_spec(sd1) + sharding2 = torch_xla._XLAC._get_xla_sharding_spec(sd2) + assert sharding1 == sharding2, f"Different sharding on tensors at {keypath}: {sharding1} vs {sharding2}" + assert torch.equal( + sd1, sd2), f"Different tensors at {keypath}:\n{sd1} vs {sd2}" + + elif isinstance(sd1, dict): + assert sd1.keys() == sd2.keys( + ), f"Different keys at {keypath}: {sd1} vs {sd2}" + for key in sd1: + self._assert_same_state_dict( + sd1[key], sd2[key], keypath=f'{keypath}.{key}') + + elif isinstance(sd1, Iterable): + for ind, (a, b) in enumerate(zip(sd1, sd2)): + self._assert_same_state_dict(a, b, keypath=f'{keypath}[{ind}]') + + else: + assert sd1 == sd2, f"Different value at {keypath}: {sd1} vs {sd2}" + class EndToEndCheckpointTest(DistributedCheckpointTestBase): @@ -357,7 +386,7 @@ class CheckpointManagerTest(DistributedCheckpointTestBase): def setUp(self): super().setUp() - # Initialize the a minimal process group + # Initialize a minimal process group dist.init_process_group( backend='gloo', init_method='tcp://localhost:8932', @@ -565,6 +594,68 @@ def test_auto_checkpoint(self, tmpdir): self.assertTrue(chkpt_mgr.reached_preemption(step)) +@unittest.skipIf(xr.device_type() != 'TPU', + 'TPU required for worker IP discovery') +class OptimizerCheckpointTest(DistributedCheckpointTestBase): + + def setUp(self): + super().setUp() + # Initialize a minimal process group + dist.init_process_group( + backend='gloo', + init_method='tcp://localhost:8932', + world_size=1, + rank=0) + + def tearDown(self): + super().tearDown() + # Destroy the CPU process group after the test + dist.destroy_process_group() + + def _get_model_and_optimizer(self, optim_cls): + model = self._get_sharded_model() + optim = optim_cls(params=model.parameters()) + return model, optim + + def _run_train_step(self, model, optim): + torch.manual_seed(42) + model(torch.ones(10, 128).to('xla')).square().sum().backward() + optim.step() + xm.mark_step() + + def _test_optimizer(self, tmpdir, optim_cls): + model, optim = self._get_model_and_optimizer(optim_cls) + self._run_train_step(model, optim) + + # Take a checkpoint including the optimizer + 
chkpt_mgr = CheckpointManager(tmpdir, 1)
+    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
+    chkpt_mgr.save(0, state_dict, force=True)
+
+    # Load the checkpoint into a new model and optimizer
+    new_model, new_optim = self._get_model_and_optimizer(optim_cls)
+    prime_optimizer(new_optim)
+    new_state_dict = {
+        'model': new_model.state_dict(),
+        'optim': new_optim.state_dict()
+    }
+    chkpt_mgr.restore(0, new_state_dict)
+    self._assert_same_state_dict(state_dict, new_state_dict)
+
+    new_model.load_state_dict(new_state_dict['model'])
+    new_optim.load_state_dict(new_state_dict['optim'])
+    self._assert_same_state_dict(new_model.state_dict(), model.state_dict())
+    self._assert_same_state_dict(new_optim.state_dict(), optim.state_dict())
+
+  @run_with_tmpdir
+  def test_sgd(self, tmpdir):
+    self._test_optimizer(tmpdir, torch.optim.SGD)
+
+  @run_with_tmpdir
+  def test_adamw(self, tmpdir):
+    self._test_optimizer(tmpdir, torch.optim.AdamW)
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 12e955adf38..57c9e441e59 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2061,6 +2061,16 @@ void InitXlaModuleBindings(py::module m) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
     return GetXLAShardingSpec(xtensor);
   });
+  m.def("_get_xla_op_sharding",
+        [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+          XLATensor::ShardingSpecPtr sharding_spec =
+              xtensor ? xtensor->sharding_spec() : nullptr;
+          if (sharding_spec != nullptr) {
+            return sharding_spec->sharding;
+          }
+          return std::nullopt;
+        });
   m.def("_get_xla_sharding_specs",
         [](const std::vector<at::Tensor>& tensors) -> std::vector<std::string> {
           tsl::profiler::TraceMe activity("_get_xla_sharding_specs",
diff --git a/torch_xla/experimental/distributed_checkpoint/__init__.py b/torch_xla/experimental/distributed_checkpoint/__init__.py
index cad57c3a405..a29b943f217 100644
--- a/torch_xla/experimental/distributed_checkpoint/__init__.py
+++ b/torch_xla/experimental/distributed_checkpoint/__init__.py
@@ -1,8 +1,10 @@
 from .manager import CheckpointManager
 from .planners import SPMDSavePlanner, SPMDLoadPlanner
+from .util import prime_optimizer

 __all__ = [
     "CheckpointManager",
     "SPMDSavePlanner",
     "SPMDLoadPlanner",
+    "prime_optimizer",
 ]
diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py
index 4ce57b5fb38..5d4ce7814e2 100644
--- a/torch_xla/experimental/distributed_checkpoint/manager.py
+++ b/torch_xla/experimental/distributed_checkpoint/manager.py
@@ -66,7 +66,10 @@ class CheckpointManager:
   >>> if tracked_steps:
   >>>   # Choose the highest step
   >>>   best_step = max(tracked_steps)
-  >>>   state_dict = {'model': model.state_dict()}
+  >>>   # Before restoring the checkpoint, the optimizer state must be primed
+  >>>   # to allow state to be loaded into it.
+ >>> prime_optimizer(optim) + >>> state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()} >>> chkpt_mgr.restore(best_step, state_dict) >>> model.load_state_dict(state_dict['model']) diff --git a/torch_xla/experimental/distributed_checkpoint/util.py b/torch_xla/experimental/distributed_checkpoint/util.py new file mode 100644 index 00000000000..198cb350323 --- /dev/null +++ b/torch_xla/experimental/distributed_checkpoint/util.py @@ -0,0 +1,44 @@ +import torch +from torch.utils._pytree import tree_map +import torch_xla +import torch_xla.core.xla_model as xm + + +def prime_optimizer(optimizer: torch.optim.Optimizer) -> torch.optim.Optimizer: + """ + Prime the optimizer state by running a dummy weight update. + + Optimizer state isn't created until after the first training step. Since the + distributed checkpointing library loads the state_dict in-place, the + optimizer state must already exist before loading the checkpoint. + + This utility method runs a dummy weight update with zero gradient to ensure + the optimizer state exists and can be loaded into. + + **Warning** This method calls `optimizer.step`, which can impact the + optimizer's state and model parameters. Therefore, it should only be used + prior to restoring a checkpoint, when the state and parameters will be + immediately overwritten. + + Args: + optimizer: The optimizer whose state should be primed for checkpoint + loading. + """ + + # Initial mark_step to ensure all param_groups are backed by device data. + xm.mark_step() + xm.wait_device_ops() + + def zero_grad(x): + if isinstance(x, torch.Tensor) and x.requires_grad: + x.grad = torch.zeros_like(x, requires_grad=False) + param_sharding = torch_xla._XLAC._get_xla_op_sharding(x) + if param_sharding: + # Match the gradient sharding to the parameter's. + torch_xla._XLAC._xla_mark_sharding(x.grad, param_sharding) + + tree_map(zero_grad, optimizer.param_groups) + optimizer.step() + xm.mark_step() + xm.wait_device_ops() + return optimizer
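
For reviewers, a minimal end-to-end sketch of the restore flow this patch enables, assuming a model and optimizer that have already been constructed and sharded under SPMD. `build_model_and_optimizer` and the bucket path are placeholders, not part of the patch; the `CheckpointManager` and `prime_optimizer` calls follow the API documented in `docs/spmd.md` above.

```python
import torch
from torch_xla.experimental.distributed_checkpoint import (
    CheckpointManager, prime_optimizer)


def build_model_and_optimizer():
  # Placeholder: construct and shard the model under SPMD, then build its
  # optimizer. Sharding setup is elided in this sketch.
  model = torch.nn.Linear(128, 128).to('xla')
  optim = torch.optim.AdamW(model.parameters())
  return model, optim


model, optim = build_model_and_optimizer()
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)

tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
  best_step = max(tracked_steps)
  # Run a dummy weight update so the lazily created optimizer state exists
  # before the in-place restore. This is destructive, so only call it when
  # the state is about to be overwritten by the checkpoint.
  prime_optimizer(optim)
  state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
  chkpt_mgr.restore(best_step, state_dict)
  model.load_state_dict(state_dict['model'])
  optim.load_state_dict(state_dict['optim'])
```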