Support FSDPv2 compute dtype #8056

Status: Open. Wants to merge 3 commits into base: master.
Showing changes from all commits.
torch_xla/experimental/spmd_fully_sharded_data_parallel.py (16 changes: 15 additions & 1 deletion)
@@ -11,6 +11,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as spmd
from torch_xla.distributed.fsdp.wrap import recursive_wrap
+from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import _cast_floats_tensors, FLOAT_DTYPES


def _prepare_spmd_partition_spec(param,
@@ -36,7 +37,7 @@ def _prepare_spmd_partition_spec(param,

class SpmdFullyShardedDataParallel(nn.Module):
"""
-This is an experiemntal implementation of rewriting FullyShardedDataParallel using SPMD.
+This is an experimental implementation of rewriting FullyShardedDataParallel using SPMD.
The usage is similar to FSDP, but with some subtle differences args.

Args:
@@ -46,6 +47,10 @@ class SpmdFullyShardedDataParallel(nn.Module):
The callable should have the signature (output, mesh) -> None.
If None, the default implementation will shard the first tensor in the output.
If the output is a tuple, only the first tensor will be sharded.
+compute_dtype (torch.dtype, Optional):
+  dtype for full parameters for computation. This defaults to
+  ``torch.float32`` but can be set to ``torch.float16`` or
+  ``torch.bfloat16``. The sharded parameters will always be in FP32.
"""

def __init__(
@@ -54,6 +59,7 @@ def __init__(
*,
mesh: Optional[spmd.Mesh] = None,
shard_output: Optional[Callable] = None,
+compute_dtype: Optional[torch.dtype] = None,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
extra_data_axis: Optional[str] = None,
@@ -107,6 +113,11 @@ def __init__(
)
self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)

+if compute_dtype is not None and compute_dtype not in FLOAT_DTYPES:
+  raise ValueError(
+      f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
+self.compute_dtype = compute_dtype or torch.float32

# Let's move the module to xla device in case it's not moved
# by the caller already.
self._orig_module = module.to(xm.xla_device())
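Reviewer note: below is a minimal usage sketch of the new compute_dtype argument. The mesh and model setup are illustrative assumptions rather than part of this PR; only the compute_dtype keyword (and the ValueError for non-float dtypes) comes from this change.

import numpy as np
import torch
import torch.nn as nn
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as spmd
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# Assumed 1-D mesh with an 'fsdp' axis, the usual FSDPv2 setup.
mesh = spmd.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))

model = SpmdFullyShardedDataParallel(
    nn.Linear(1024, 1024),
    mesh=mesh,
    # Sharded parameters stay in FP32; float inputs are cast to bf16 in forward().
    compute_dtype=torch.bfloat16)

# Passing a non-float dtype (e.g. torch.int8) raises ValueError, since
# compute_dtype must be one of FLOAT_DTYPES.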
@@ -157,6 +168,9 @@ def module(self) -> nn.Module:
return self._orig_module

def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+if self.compute_dtype != torch.float32:
+  # Cast the input float tensors to the specified compute_dtype
+  args, kwargs = _cast_floats_tensors(self.compute_dtype, *args, **kwargs)
output = self.module(*args, **kwargs)
# Need to shard the output of the forward to instruct the compiler
# to enforce the FSDP algorithm.
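For reviewers unfamiliar with the reused helper: a behavioral sketch of what the _cast_floats_tensors call in forward() amounts to. The cast_floats function below is a simplified, hypothetical stand-in that only handles flat args/kwargs; the real helper lives in torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel and is not reproduced here.

import torch

def cast_floats(dtype, *args, **kwargs):
  # Cast every floating-point tensor argument to `dtype`; leave everything
  # else (integer tensors, Python scalars, None, ...) untouched.
  def maybe_cast(x):
    if isinstance(x, torch.Tensor) and x.is_floating_point():
      return x.to(dtype)
    return x
  return (tuple(maybe_cast(a) for a in args),
          {k: maybe_cast(v) for k, v in kwargs.items()})

# Example: float32 activations become bfloat16, integer token ids pass through.
hidden = torch.randn(2, 8)                    # float32
token_ids = torch.tensor([[1, 2], [3, 4]])    # int64
(hidden, token_ids), kwargs = cast_floats(torch.bfloat16, hidden, token_ids)
assert hidden.dtype == torch.bfloat16 and token_ids.dtype == torch.int64

With compute_dtype left at its default of torch.float32, forward() skips the cast entirely, preserving the existing behavior.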