Use reduce-scatter coalescing for FSDP (#6024)
Co-authored-by: Mason Fu <fuxinwe@amazon.com>
jeffhataws and Mason Fu committed Aug 2, 2024
1 parent d29b761 commit 7fe070a
Showing 3 changed files with 166 additions and 52 deletions.
87 changes: 73 additions & 14 deletions torch_xla/distributed/fsdp/utils.py
@@ -60,21 +60,80 @@ def dummy_all_reduce(reduce_type, inputs, scale=1.0, groups=None):
return [t.mul_(scale) for t in inputs]


-def dummy_reduce_scatter(reduce_type,
-                         input,
-                         scale,
-                         scatter_dim,
-                         shard_count,
-                         groups=None):
-  assert shard_count == xr.world_size()
-  full_size = input.size(scatter_dim)
-  shard_size = full_size // xr.world_size()
-  begin = shard_size * xr.global_ordinal()
-  end = begin + shard_size
-  slices = [None] * input.dim()
-  slices[scatter_dim] = slice(begin, end)
-  return input[tuple(slices)] * scale
+class DummyReduceScatter:
+  """A dummy op for debugging with the same output shape as reduce_scatter"""
+
+  def __init__(self, shard_count):
+    assert shard_count == xm.xrt_world_size()
+    self.scale = 1.0
+
+  def __call__(self, input, callback):
+    full_size = input.size(0)
+    shard_size = full_size // xm.xrt_world_size()
+    begin = shard_size * xm.get_ordinal()
+    end = begin + shard_size
+    slices = [None] * input.dim()
+    slices[0] = slice(begin, end)
+    callback(input[tuple(slices)])
+
+  def flush(self):
+    pass


class BucketizedReduceScatter:
  """A reduce_scatter op that groups input tensors before reduce-scattering them."""

  def __init__(self, bucket_size_mb, shard_count, groups, pin_layout) -> None:
    self.bucket_size_bytes = bucket_size_mb * 1024 * 1024
    self.shard_count = shard_count
    self.groups = groups
    self.pin_layout = pin_layout
    self.scale = 1.0

    self.callbacks = []
    self.bucket = []
    self.bucket_watermark = 0

  def __call__(self, input, callback):
    input_byte_size = input.element_size() * input.numel()
    self.bucket.append(input)
    self.callbacks.append(callback)
    self.bucket_watermark += input_byte_size
    # If bucket_size_mb is left at the default of 0, flush after every tensor
    # rather than coalescing.
    if self.bucket_watermark > self.bucket_size_bytes:
      self.flush()

  def flush(self):
    if not self.bucket:
      return
    # TODO: debug coalesce error "" for GPU when pin_layout=True.
    # For now, work around it by using the non-coalesced version of reduce-scatter
    # when there is only one tensor input (bucket_size_mb=0).
    if len(self.bucket) == 1:
      result = xm.reduce_scatter(
          xm.REDUCE_SUM,
          self.bucket[0],
          scale=self.scale,
          scatter_dim=0,
          shard_count=self.shard_count,
          groups=self.groups,
          pin_layout=self.pin_layout)
      self.callbacks[0](result)
    else:
      results = xm.reduce_scatter(
          xm.REDUCE_SUM,
          self.bucket,
          scale=self.scale,
          scatter_dim=0,
          shard_count=self.shard_count,
          groups=self.groups,
          pin_layout=self.pin_layout)
      for cb, result in zip(self.callbacks, results):
        cb(result)

    self.bucket.clear()
    self.callbacks.clear()
    self.bucket_watermark = 0


class XLAPatchedLinear(torch.autograd.Function):
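The coalescing contract above is easiest to see in isolation. The sketch below is not torch_xla code: FakeBucketizedReduceScatter is a hypothetical stand-in that reproduces the bucketing, callback, and flush behavior with a local slice in place of the real collective, so it runs on CPU-only PyTorch.

import torch

class FakeBucketizedReduceScatter:
  """Collects (tensor, callback) pairs and services them in batches, like the real op."""

  def __init__(self, bucket_size_mb=0):
    self.bucket_size_bytes = bucket_size_mb * 1024 * 1024
    self.bucket, self.callbacks, self.bucket_watermark = [], [], 0

  def __call__(self, tensor, callback):
    self.bucket.append(tensor)
    self.callbacks.append(callback)
    self.bucket_watermark += tensor.element_size() * tensor.numel()
    # A bucket_size_mb of 0 makes the watermark check trip on every call.
    if self.bucket_watermark > self.bucket_size_bytes:
      self.flush()

  def flush(self):
    for t, cb in zip(self.bucket, self.callbacks):
      cb(t[:t.size(0) // 2])  # local slice standing in for the reduce-scattered shard
    self.bucket.clear()
    self.callbacks.clear()
    self.bucket_watermark = 0

reducer = FakeBucketizedReduceScatter(bucket_size_mb=1)
shards = []
for _ in range(4):
  # Callbacks fire when a bucket overflows or when flush() is called explicitly.
  reducer(torch.randn(1024), shards.append)
reducer.flush()  # the final flush FSDP issues after the last post-backward hook
print(len(shards))  # 4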
118 changes: 80 additions & 38 deletions torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
@@ -36,7 +36,14 @@
import torch_xla.core.xla_model as xm

from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper
-from .utils import dummy_all_gather, dummy_all_reduce, dummy_reduce_scatter, apply_xla_patch_to_nn_linear
+from .utils import (
+    BucketizedReduceScatter,
+    DummyReduceScatter,
+    dummy_all_gather,
+    dummy_all_reduce,
+    apply_xla_patch_to_nn_linear,
+)

from .wrap import recursive_wrap
from ._init_utils import _materialize_module

@@ -45,6 +52,8 @@
XLA_DISABLE_FUNCTIONALIZATION = bool(
os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False))

from torch_xla.utils.checkpoint import chkpt_status

FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]


@@ -296,6 +305,7 @@ def __init__(
sharding_world_size: Optional[int] = None,
shard_param_on_dim_0: bool = False,
pin_layout_in_collective_ops: bool = True,
reduce_scatter_bucket_size_mb: Optional[int] = 0,
coalesce_all_gather_ops: bool = False,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
@@ -399,6 +409,20 @@ def __init__(
# When `_shard_param_on_dim_0` is True, we shard and all-gather model parameter tensors
# only along their dim 0 without flattening the parameter
self._shard_param_on_dim_0 = shard_param_on_dim_0 and not flatten_parameters
# Allow specifying groups for the sharding collective ops, useful for mixing
# FSDP data parallelism with model parallelism (e.g. Megatron)
self.sharding_groups = sharding_groups
if sharding_groups is None:
self.rank = xm.get_ordinal()
self.world_size = xm.xrt_world_size()
else:
if sharding_rank is None or sharding_world_size is None:
raise ValueError(
"sharding_rank and sharding_world_size must be provided when sharding_groups is specified"
)
self.rank = sharding_rank
self.world_size = sharding_world_size

self.coalesce_all_gather_ops = coalesce_all_gather_ops
# Set layout pinning to False in all_gather, all_reduce, and reduce_scatter so that they can work together
# TODO (ronghanghu): change the default layout pinning to True after it's supported simultaneously
@@ -414,10 +438,13 @@ def __init__(
self.all_reduce_op = functools.partial(
xm.all_reduce, pin_layout=pin_layout_in_collective_ops)
if _debug_dummy_reduce_scatter_op:
-      self.reduce_scatter_op = dummy_reduce_scatter
+      self.reduce_scatter_op = DummyReduceScatter(shard_count=self.world_size)
    else:
-      self.reduce_scatter_op = functools.partial(
-          xm.reduce_scatter, pin_layout=pin_layout_in_collective_ops)
+      self.reduce_scatter_op = BucketizedReduceScatter(
+          reduce_scatter_bucket_size_mb,
+          shard_count=self.world_size,
+          groups=self.sharding_groups,
+          pin_layout=pin_layout_in_collective_ops)
if _debug_dummy_optimization_barrier_op:
self.optimization_barrier_op = lambda *args: None
else:
@@ -555,6 +582,10 @@ def set_gradient_divide_factors(self, pre: float, post: float,
module.set_gradient_divide_factors(pre, post, False)
self.gradient_predivide_factor = pre
self.gradient_postdivide_factor = post
if (pre, post) == (1, 1):
self.reduce_scatter_op.scale = 1.0 / self.world_size
else:
self.reduce_scatter_op.scale = 1.0
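Folding the 1/world_size factor into the collective is what keeps gradients averaged once the explicit division is disabled: with world_size = 4 and (pre, post) = (1, 1), each reduced gradient shard is multiplied by 0.25 inside the reduce-scatter rather than being divided by the pre- and post-divide factors around it.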

@property
def module(self) -> XlaFlattenParamsWrapper:
@@ -943,7 +974,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# This can be used to debug FSDP parameter memory consumption.
outputs = self._dummy_forward(*args, **kwargs)

-    if self.reshard_after_forward:
+    # All-gather reduction optimization: if this forward is a recompute inside a
+    # checkpointed backward, do not reshard here, so the backward that follows
+    # does not need to all-gather the full parameters again.
+    if self.reshard_after_forward and not chkpt_status.in_chkpt_bwd:
output_opt_barrier_tensors = []
if self.optimization_barrier_in_forward:
# Ensure that the full parameters of this FSDP module are freed
@@ -1060,7 +1094,9 @@ def _pre_backward_hook(t_grad: torch.Tensor) -> None:
# All-gather full parameters or switching to the full params.
# Note, ``self._rebuild_full_params`` is idempotent. So in case it is called
# unnecessarily, it doesn't incur much overhead.
-      if self.reshard_after_forward:
+      # All-gather reduction optimization: if this backward runs inside a
+      # checkpointed backward, skip the all-gather here, since the preceding
+      # recompute forward did not reshard the parameters.
+      if self.reshard_after_forward and not chkpt_status.in_chkpt_bwd:
dependency_tensors = []
if self.optimization_barrier_in_backward:
# Ensure that backward pass ops of feature gradients, parameter
@@ -1145,6 +1181,7 @@ def _register_post_backward_hooks(self) -> None:
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
self._post_backward_hooks_to_call = 0
for p in self.full_params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
@@ -1158,6 +1195,7 @@ def _register_post_backward_hooks(self) -> None:
handle = grad_acc.register_hook(
functools.partial(self._post_backward_hook, p))
p._shard_bwd_hook = (grad_acc, handle)
self._post_backward_hooks_to_call += 1
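The counter registered here pairs with the decrement in _post_backward_hook below: when the last registered hook runs, it calls reduce_scatter_op.flush() so any gradients still waiting in the bucket are reduce-scattered before the backward pass completes.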

@torch.no_grad()
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
@@ -1184,7 +1222,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
self._post_backward_hooks_to_call -= 1
if param.grad is None:
if self._post_backward_hooks_to_call == 0:
self.reduce_scatter_op.flush()
return

assert param.grad is not None, param.shape
@@ -1205,6 +1246,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
apply_opt_barrier=self.optimization_barrier_in_backward)

if not self._require_backward_grad_sync:
if self._post_backward_hooks_to_call == 0:
self.reduce_scatter_op.flush()
return

if self.gradient_predivide_factor > 1:
@@ -1220,38 +1263,37 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
self.optimization_barrier_op([grad_flat])
if grad_flat.dtype != torch.float32 and self.fp32_reduce_scatter:
grad_flat = grad_flat.to(torch.float32)
-    reduced_grad = self.reduce_scatter_op(
-        xm.REDUCE_SUM,
-        grad_flat.detach(),
-        scale=1.0,
-        scatter_dim=0,
-        shard_count=self.world_size,
-        groups=self.sharding_groups)
-    if reduced_grad.dtype != torch.float32:
-      reduced_grad = reduced_grad.to(torch.float32)
-    if self.optimization_barrier_in_backward:
-      self.optimization_barrier_op([reduced_grad])
-    if self.gradient_postdivide_factor > 1:
-      # Average grad by world_size for consistency with PyTorch DDP.
-      reduced_grad.div_(self.gradient_postdivide_factor)
-
-    grad._has_full_param = True
-    grad_flat._has_full_param = True
-    self._free_full_params(
-        [grad, grad_flat],
-        dependency_tensors=[reduced_grad],
-        apply_opt_barrier=self.optimization_barrier_in_backward)
-    self._try_adding_to_backward_opt_barrier_lists(reduced_grad)
-
-    # Accumulate into the gradient shard.
-    assert hasattr(param, "_sharded_param")
-    p_shard = param._sharded_param
-    if p_shard.grad is None:
-      p_shard.grad = reduced_grad
-    else:
-      assert p_shard.grad.shape == reduced_grad.shape
-      assert p_shard.grad.device == reduced_grad.device
-      p_shard.grad += reduced_grad
+    def reduce_scatter_done(reduced_grad):
+      if reduced_grad.dtype != torch.float32:
+        reduced_grad = reduced_grad.to(torch.float32)
+      if self.optimization_barrier_in_backward:
+        self.optimization_barrier_op([reduced_grad])
+      if self.gradient_postdivide_factor > 1:
+        # Average grad by world_size for consistency with PyTorch DDP.
+        reduced_grad.div_(self.gradient_postdivide_factor)
+
+      grad._has_full_param = True
+      grad_flat._has_full_param = True
+      self._free_full_params(
+          [grad, grad_flat],
+          dependency_tensors=[reduced_grad],
+          apply_opt_barrier=self.optimization_barrier_in_backward)
+      self._try_adding_to_backward_opt_barrier_lists(reduced_grad)
+
+      # Accumulate into the gradient shard.
+      assert hasattr(param, "_sharded_param")
+      p_shard = param._sharded_param
+      if p_shard.grad is None:
+        p_shard.grad = reduced_grad
+      else:
+        assert p_shard.grad.shape == reduced_grad.shape
+        assert p_shard.grad.device == reduced_grad.device
+        p_shard.grad += reduced_grad
+
+    self.reduce_scatter_op(grad_flat.detach(), reduce_scatter_done)
+    if self._post_backward_hooks_to_call == 0:
+      self.reduce_scatter_op.flush()

def _queue_wait_for_post_backward(self) -> None:
"""
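From the user's side, the new reduce_scatter_bucket_size_mb constructor argument is the only switch added here. A minimal construction sketch, assuming a torch_xla installation with an XLA device and a process launched in the usual way (e.g. via xmp.spawn); the model and the 50 MB bucket size are illustrative:

import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))
model = FSDP(
    model.to(xm.xla_device()),
    # Coalesce gradient reduce-scatters into ~50 MB buckets; the default of 0
    # keeps the previous behavior of one reduce-scatter per flattened gradient.
    reduce_scatter_bucket_size_mb=50)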
13 changes: 13 additions & 0 deletions torch_xla/utils/checkpoint.py
@@ -59,6 +59,16 @@ def set_device_states(devices: List[torch.device],
device_module.set_rng_state(state, device=device)


class CheckpointStatus:

def __init__(self):
self.in_chkpt_bwd = False


# chkpt_status is used for the FSDP all-gather reduction optimization
chkpt_status = CheckpointStatus()


class CheckpointFunction(torch.autograd.Function):

def _extract_tensors_from_list(inputs):
@@ -123,6 +133,8 @@ def forward(ctx, run_function, preserve_rng_state, *args):

@staticmethod
def backward(ctx, *args):
chkpt_status.in_chkpt_bwd = True

if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
@@ -190,6 +202,7 @@ def backward(ctx, *args):
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)

chkpt_status.in_chkpt_bwd = False
return (None, None) + grads
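The flag only matters when torch_xla's checkpoint wrapper runs around FSDP-wrapped submodules with reshard_after_forward=True: during the checkpointed backward, the recompute forward skips the reshard and the backward skips the matching all-gather. A hedged sketch of that combination (CheckpointedStack and its sizes are illustrative, and an XLA device is assumed):

import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
  """Recomputes each FSDP-wrapped block under torch_xla's checkpoint."""

  def __init__(self, num_blocks=2, dim=256):
    super().__init__()
    self.blocks = nn.ModuleList(
        FSDP(nn.Linear(dim, dim).to(xm.xla_device())) for _ in range(num_blocks))

  def forward(self, x):
    for block in self.blocks:
      # While chkpt_status.in_chkpt_bwd is True, the recompute forward of `block`
      # does not reshard, so its backward does not need another all-gather.
      x = checkpoint(block, x)
    return x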

