From 551a76cd3cbc372ab09e2e9e39d1a11c7c9f4886 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Mon, 10 Jun 2024 12:43:18 -0700 Subject: [PATCH] [ZeRO-1] Sync features from r2.1_aws_neuron branch (#7132) --- test/run_tests.sh | 2 +- test/test_zero1.py | 76 ++++++-- torch_xla/core/xla_model.py | 2 +- .../distributed/zero_redundancy_optimizer.py | 162 +++++++++++++++--- 4 files changed, 198 insertions(+), 44 deletions(-) diff --git a/test/run_tests.sh b/test/run_tests.sh index 531a9b0c8a9..4a298f01ee5 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -184,7 +184,6 @@ function run_xla_op_tests1 { run_test "$CDIR/test_python_ops.py" run_test "$CDIR/test_ops.py" run_test "$CDIR/test_metrics.py" - run_test "$CDIR/test_zero1.py" run_test "$CDIR/dynamo/test_dynamo_integrations_util.py" run_test "$CDIR/dynamo/test_dynamo_aliasing.py" run_test "$CDIR/dynamo/test_dynamo.py" @@ -296,6 +295,7 @@ function run_mp_op_tests { run_test "$CDIR/test_mp_collective_permute.py" run_test "$CDIR/test_mp_all_gather.py" run_test "$CDIR/test_mp_reduce_scatter.py" + run_test "$CDIR/test_zero1.py" run_test "$CDIR/test_mp_distributed_mm.py" run_test "$CDIR/test_mp_save.py" run_test "$CDIR/test_mp_mesh_reduce.py" diff --git a/test/test_zero1.py b/test/test_zero1.py index 17c46617973..f51f8ae3f8d 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -2,64 +2,104 @@ import torch.nn as nn import torch_xla import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer from torch_xla import runtime as xr -from torch.testing._internal.common_utils import TestCase from copy import deepcopy +import sys import unittest +import test_utils -class XlaZeRO1Test(TestCase): + +def _get_partial_states(s): + dp_size = xr.global_device_count() + dp_rank = xr.global_ordinal() + + def convert_fn(tensors): + torch_xla._XLAC._xla_sync_multi( + tensors, devices=[], wait=True, 
sync_xla_data=True) + ret = [] + for t in tensors: + ret.append(t.chunk(dp_size)[dp_rank].detach().cpu()) + return ret + + def select_fn(v): + return type(v) == torch.Tensor and xm.is_xla_tensor(v) + + return xm.ToXlaTensorArena(convert_fn, select_fn).transform(s) + + +class XlaZeRO1Test(test_utils.XlaTestCase): @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU") - @unittest.skipIf(xr.device_type() == 'CUDA', "Crash on CUDA") def test_zero1(self): device = xm.xla_device() - model = nn.Linear(8, 8) - x = torch.ones((8, 8)) + model = nn.Linear(32, 32) + x = torch.ones((32, 32)) + x.requires_grad = True model = model.to(device) x = x.to(device) y = model(x).sum() y.backward() + xm.mark_step() opt1 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + opt1.step() + xm.mark_step() + opt2 = ZeroRedundancyOptimizer( model.parameters(), torch.optim.SGD, lr=0.01, momentum=0.9, grad_clipping=False) - - opt1.step() opt2.step() + xm.mark_step() + s1 = opt1.state_dict() s2 = opt2.state_dict() - self.assertEqual(s1['state'], s2['base_state']) + self.assertEqual(_get_partial_states(s1['state']), s2['base_state']) - # deepcopy s1 to load later because pytorch optimizers do not guarantee the input - # state_dict will not be modified. on the other hand, s2 has this guarantee. 
- s1_clone = deepcopy(s1) + s1_clone = deepcopy(xm._maybe_convert_to_cpu(s1)) + s2_clone = deepcopy(xm._maybe_convert_to_cpu(s2)) opt1.load_state_dict(s1) opt2.load_state_dict(s2) - self.assertEqual(opt1.state_dict()['state'], - opt2.state_dict()['base_state']) + self.assertEqual( + _get_partial_states(opt1.state_dict()['state']), + opt2.state_dict()['base_state']) # step still runnable opt1.step() opt2.step() + xm.mark_step() + opt1.load_state_dict(s1_clone) - opt2.load_state_dict(s2) - self.assertEqual(opt1.state_dict()['state'], - opt2.state_dict()['base_state']) + opt2.load_state_dict(s2_clone) + xm.mark_step() + self.assertEqual( + _get_partial_states(opt1.state_dict()['state']), + opt2.state_dict()['base_state']) # step still runnable opt1.step() opt2.step() + xm.mark_step() + + +def _mp_fn(index): + device = xm.xla_device() + if xm.xla_device_hw(device) in ('TPU', 'CUDA'): + test = unittest.main(exit=False) + sys.exit(0 if test.result.wasSuccessful() else 1) + else: + print( + 'Default device {} is not a TPU or CUDA device'.format(device), + file=sys.stderr) if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) + xmp.spawn(_mp_fn, args=()) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 149aa99b67d..4fb7cc1316a 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -967,7 +967,7 @@ def reduce_scatter_bucketized(reduce_type, see reduce_scatter for reduce_type, scale, scatter_dim, shard_count, groups, pin_layout input_list: List of input tensors output: Optional list of output torch.Tensor - bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing all-gather. + bucket_cap_mb: Number of MegaBytes of the tensor bucket to fill before doing reduce-scatter. Returns: A list of `torch.Tensors` with all the values reduced across replicas. 
Each process diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index 9b21fe4ead8..7e1e7b6cc10 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -1,4 +1,5 @@ import copy +import logging from typing import (Any, Iterator, Optional, Type, Union, List, Dict) import torch @@ -9,6 +10,7 @@ import torch_xla import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr class ZeroRedundancyOptimizer(Optimizer): @@ -32,17 +34,28 @@ class ZeroRedundancyOptimizer(Optimizer): collective ops (all_gather and reduce_scatter). See `xm.all_reduce` for details on pinning layout. Default: True sharding_groups (list, Optional): - If specified, ZeRO-1 will use this ``sharding_groups`` for all-gather - and reduce-scatter ops in full parameter construction and gradient - sharding. This can be useful for mixing ZeRO-1 with model parallelism - such as Megatron. + If specified, ZeRO-1 will use this ``sharding_groups`` for all-gather + and reduce-scatter ops in full parameter construction and gradient + sharding. This can be useful for mixing ZeRO-1 with model parallelism + such as Megatron. grad_norm_groups (list, Optional): - If specified, ZeRO-1 will use this ``grad_norm_groups`` for the - EXTRA all-reduce op in grad norm calculation. This can be model parallel - groups when mixing ZeRO-1 with model parallelism such as Megatron. - bucket_cap_mb: - If non-zero, specifies the maximum number of megabytes to combine tensors - before doing the all-gather/reduce-scatter operations. + If specified, ZeRO-1 will use this ``grad_norm_groups`` for the + EXTRA all-reduce op in grad norm calculation. This can be model parallel + groups when mixing ZeRO-1 with model parallelism such as Megatron. + lazy_init (bool, Optional): if ``True``, the class will not shard parameters + during initialization. Users need to call ``optimizer.init_zero()`` by themselves. 
+ Default: False + bucket_cap_mb_all_gather (int, Optional): Number of MegaBytes of the tensor bucket to fill before + doing all-gather. Default: 0 + bucket_cap_mb_reduce_scatter (int, Optional): Number of MegaBytes of the tensor bucket to fill before + doing reduce-scatter. Default: 0 + use_grad_acc_hook (bool, Optional): if ``True``, use hooks for gradient accumulation, then + ``dtype`` of grad accumulation will be the same as ``optimizer_dtype``. Users can set this + to True to use higher precision for gradient accumulation. Default: False + save_master_weights (bool, Optional): + if ``True``, also save sharded master weights. Default: False + higher_cc_precision (bool, Optional): if ``True``, use higher precision for collective communication + operators (the same as ``optimizer_dtype``). Default: False **defaults: any trailing arguments, which are forwarded to the local optimizer. @@ -65,12 +78,20 @@ def __init__( lazy_init: bool = False, bucket_cap_mb_all_gather: int = 0, bucket_cap_mb_reduce_scatter: int = 0, + use_grad_acc_hook: bool = False, + save_master_weights: bool = False, + higher_cc_precision: bool = False, **defaults: Any, ): + if not save_master_weights: + logging.warning( + 'Not saving master weights may have accuracy issues when resuming training!' 
+ ) + super().__init__(params, defaults) - self.global_world_size = xm.xrt_world_size() - self.global_rank = xm.get_ordinal() + self.global_world_size = xr.global_device_count() + self.global_rank = xr.global_ordinal() self._sharding_groups = [list(range(self.global_world_size)) ] if sharding_groups is None else sharding_groups self._grad_norm_groups = grad_norm_groups @@ -85,6 +106,11 @@ def __init__( self.bucket_cap_mb_reduce_scatter = bucket_cap_mb_reduce_scatter self.coalesce_cc_all_gather = bucket_cap_mb_all_gather > 0 self.coalesce_cc_reduce_scatter = bucket_cap_mb_reduce_scatter > 0 + self.save_master_weights = save_master_weights + self.higher_cc_precision = higher_cc_precision + self.use_grad_acc_hook = use_grad_acc_hook + self.grad_accs = [] + self.grad_acc_hooks = [] self._grad_norm = None @@ -93,6 +119,7 @@ def __init__( self.init_zero() def init_zero(self): + self.remove_hooks() self.local_world_size = len(self.sharding_groups[0]) # Infer the local rank from the group self.local_rank = None @@ -169,6 +196,41 @@ def _shard_tensor(self, tensor: torch.Tensor): tensor = tensor.chunk(self.local_world_size)[self.local_rank] return tensor + def _make_param_hook(self, param, shard): + """ + Create the grad accumulation hook for backprop. + """ + + def _param_hook(*unused): + # Accumulates gradients on main gradients + if param.grad is not None: + if not hasattr(shard, 'main_grad'): + # Create main gradients + shard.main_grad = torch.zeros( + param.shape, + dtype=self.optimizer_dtype, + device=self.device, + requires_grad=False) + param.main_grad = shard.main_grad + shard.main_grad.add_(param.grad.data.to(dtype=self.optimizer_dtype)) + # Deallocate grad memory. 
+ param.grad = None + + return _param_hook + + def _register_hook(self, param, shard): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + hook = grad_acc.register_hook(self._make_param_hook(param, shard)) + self.grad_acc_hooks.append(hook) + self.grad_accs.append(grad_acc) + + def remove_hooks(self): + for hook in self.grad_acc_hooks: + hook.remove() + self.grad_acc_hooks = [] + self.grad_accs = [] + def _shard_parameters(self): """ Shard all parameters. @@ -196,6 +258,8 @@ def _shard_parameters(self): shard_data = shard_data.to(dtype=self.optimizer_dtype) shard_data = shard_data.to(device=self.device) # move to xla device shard = nn.Parameter(shard_data, requires_grad=param.requires_grad) + if self.use_grad_acc_hook: + self._register_hook(param, shard) sharded_params.append(shard) sharded_params_group = copy.copy(param_group) sharded_params_group['params'] = sharded_params @@ -282,9 +346,13 @@ def step(self, closure=None, **kwargs): self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: - padded_grad = self._pad_to_world_size(param.grad, - self.local_world_size) + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): + padded_grad = self._pad_to_world_size( + shard.main_grad if self.use_grad_acc_hook else param.grad, + self.local_world_size) + if self.higher_cc_precision: + padded_grad = padded_grad.to(dtype=self.optimizer_dtype) if self.coalesce_cc_reduce_scatter: padded_grads.append(padded_grad) else: @@ -317,7 +385,8 @@ def step(self, closure=None, **kwargs): self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): grad_shard = grad_shards[index] if grad_shard.dtype != self.optimizer_dtype: 
grad_shard = grad_shard.to(dtype=self.optimizer_dtype) @@ -334,16 +403,24 @@ def step(self, closure=None, **kwargs): # Remove shards' grads self.base_optimizer.zero_grad(set_to_none=True) + self.allgather_weights_and_update_full_parameter() + + # sync back + self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups) + + return loss + + def allgather_weights_and_update_full_parameter(self): # All gather the new weights across the ranks and assign them to the full parameters sharded_data = [] for param_group, sharded_param_group in zip( self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): shard_data = shard.data - if param.dtype != self.optimizer_dtype: - shard_data = shard_data.to(dtype=param.dtype) + shard_data = shard_data.to(dtype=param.dtype) if self.coalesce_cc_all_gather: sharded_data.append(shard_data) else: @@ -368,21 +445,34 @@ def step(self, closure=None, **kwargs): self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group['params'], sharded_param_group['params']): - if param.grad is not None: + if param.grad is not None or (self.use_grad_acc_hook and + hasattr(shard, 'main_grad')): padded_param = padded_params[index] param.data.copy_(padded_param.data[:param.size(0)]) index += 1 - # sync back - self._sync_param_groups(self.base_optimizer.param_groups, self.param_groups) - - return loss + def zero_grad(self, set_to_none: bool = False): + super().zero_grad(set_to_none=set_to_none) + if self.use_grad_acc_hook: + for sharded_param_group in self.base_optimizer.param_groups: + for shard in sharded_param_group['params']: + if hasattr(shard, 'main_grad'): + shard.main_grad.zero_() def state_dict(self): state_dict = super().state_dict() base_state = self.base_optimizer.state_dict()['state'] state_dict['base_state'] = 
base_state state_dict['shape_info'] = self.get_shape_info() + if self.save_master_weights: + index = 0 + master_weights = {} + for sharded_param_group in self.base_optimizer.param_groups: + for shard in sharded_param_group['params']: + master_weights[index] = shard.data + index += 1 + state_dict['sharded_master_weights'] = master_weights + return state_dict def load_state_dict(self, state_dict): @@ -396,6 +486,30 @@ def load_state_dict(self, state_dict): tmp = self.base_optimizer.state_dict() tmp['state'] = base_state self.base_optimizer.load_state_dict(tmp) + if 'sharded_master_weights' in state_dict: + master_weights = state_dict['sharded_master_weights'] + index = 0 + for param_group, sharded_param_group in zip( + self.param_groups, self.base_optimizer.param_groups): + for param, shard in zip(param_group['params'], + sharded_param_group['params']): + shard.data.copy_(master_weights[index]) + # set dummy gradient for allgather to be triggered. + if self.use_grad_acc_hook: + # Create main gradients + shard.main_grad = torch.zeros( + param.shape, + dtype=self.optimizer_dtype, + device=self.device, + requires_grad=False) + param.main_grad = shard.main_grad + else: + param.grad = torch.zeros_like(param.data) + index += 1 + xm.mark_step() + # add mark_step around allgather to avoid large number of compilation + self.allgather_weights_and_update_full_parameter() + xm.mark_step() def get_shape_info(self): shape_info = {}