Error if device mesh specified in fsdp config (#3580)

snarayan21 committed Aug 27, 2024
1 parent 6f18ff8 commit 4e6606e
Showing 4 changed files with 63 additions and 2 deletions.
23 changes: 21 additions & 2 deletions composer/utils/parallelism.py
@@ -3,7 +3,7 @@

"""Parallelism configs."""

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Optional

from torch.distributed._tensor.device_mesh import DeviceMesh
@@ -23,7 +23,6 @@ class FSDPConfig:
cpu_offload: bool = False
data_parallel_shard_degree: int = -1
data_parallel_replicate_degree: Optional[int] = None
device_mesh: Optional[DeviceMesh] = None
forward_prefetch: bool = False
forward_prefetch_limit: int = 1
ignored_modules: Optional[Any] = None
@@ -41,6 +40,26 @@ class FSDPConfig:
use_orig_params: bool = True
verbose: bool = False

_device_mesh: Optional[DeviceMesh] = field(default=None, init=False, repr=False)

def __init__(self, **kwargs):
if 'device_mesh' in kwargs or '_device_mesh' in kwargs:
raise ValueError(
f'Directly specifying device mesh for FSDP was deprecated in Composer version 0.24.0. ' +
f"Please specify 'data_parallel_shard_degree' and/or 'data_parallel_replicate_degree' instead.",
)

for k, v in kwargs.items():
setattr(self, k, v)

@property
def device_mesh(self) -> Optional[DeviceMesh]:
return self._device_mesh

@device_mesh.setter
def device_mesh(self, value: Optional[DeviceMesh]):
self._device_mesh = value


@dataclass
class TPConfig:
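A minimal sketch (not part of this commit) of the behavior the new __init__ enforces: the degree-based knobs remain the supported way to describe data parallelism, while passing device_mesh (or the private _device_mesh field) now raises. Degree values below are illustrative.

# Sketch only; assumes 4 ranks so that 2 x 2 (replicate x shard) is valid.
from composer.utils.parallelism import FSDPConfig

# Supported path: describe the mesh via shard/replicate degrees (HSDP here).
config = FSDPConfig(
    data_parallel_shard_degree=2,
    data_parallel_replicate_degree=2,
)
assert config.device_mesh is None  # the mesh is built internally, never passed in

# Deprecated path: passing a mesh directly raises with a pointer to the new knobs.
try:
    FSDPConfig(device_mesh=[2, 2])
except ValueError as err:
    print(err)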
11 changes: 11 additions & 0 deletions tests/checkpoint/helpers.py
@@ -150,6 +150,17 @@ def init_model(
if wrap_with_raw_fsdp:
model = FSDP(model, **fsdp_kwargs)
else:
if 'device_mesh' in fsdp_kwargs:
mesh = fsdp_kwargs.pop('device_mesh')
ndim = mesh.ndim
if ndim == 1:
fsdp_kwargs['data_parallel_shard_degree'] = mesh.size(0)
elif ndim == 2:
fsdp_kwargs['data_parallel_replicate_degree'] = mesh.size(0)
fsdp_kwargs['data_parallel_shard_degree'] = mesh.size(1)
else:
raise ValueError(f'Unsupported device mesh dimension: {ndim}')

prepare_fsdp_module(
model,
optimizers=None,
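The shim above assumes the conventional mesh ordering: a 1-D mesh means pure sharding, and a 2-D mesh is ordered [replicate, shard]. A standalone sketch of the same mapping (the function name is illustrative, not from the commit):

from typing import Any, Dict, Sequence

def mesh_sizes_to_degrees(mesh_sizes: Sequence[int]) -> Dict[str, Any]:
    """Translate a legacy device-mesh shape into degree-based FSDP kwargs."""
    if len(mesh_sizes) == 1:
        # 1-D mesh: every rank shards the model (plain FSDP).
        return {'data_parallel_shard_degree': mesh_sizes[0]}
    if len(mesh_sizes) == 2:
        # 2-D mesh: outer dim replicates, inner dim shards (HSDP).
        return {
            'data_parallel_replicate_degree': mesh_sizes[0],
            'data_parallel_shard_degree': mesh_sizes[1],
        }
    raise ValueError(f'Unsupported device mesh dimension: {len(mesh_sizes)}')

# e.g. a legacy {'fsdp': {'device_mesh': [2, 4]}} becomes
# {'data_parallel_replicate_degree': 2, 'data_parallel_shard_degree': 4}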
29 changes: 29 additions & 0 deletions tests/trainer/test_fsdp.py
@@ -13,6 +13,7 @@
from composer.models import ComposerClassifier, ComposerModel
from composer.trainer.trainer import Trainer, _fsdp_reshard_and_cleanup
from composer.utils import dist
from composer.utils.parallelism import FSDPConfig
from tests.common import (
EmbeddedWeightTiedModel,
RandomClassificationDataset,
@@ -551,6 +552,34 @@ def oom_hook(module, grad_input, grad_ouput):
assert torch.equal(output_1, output_2)


@pytest.mark.gpu
@world_size(2)
def test_fsdp_device_mesh(world_size: int):
model = SimpleModel()
model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]
model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues]

# Expect error via pytest
with pytest.raises(ValueError, match='Directly specifying device mesh for FSDP was deprecated*'):
Trainer(
model=model,
parallelism_config={'fsdp': {
'device_mesh': [2],
}},
max_duration='3ba',
)


@pytest.mark.parametrize('error_key', ['device_mesh', '_device_mesh'])
def test_fsdp_config_device_mesh_error(error_key: str):
# Passing device mesh directly to FSDPConfig should raise an error
with pytest.raises(ValueError, match='Directly specifying device mesh for FSDP was deprecated*'):
cfg_dict = {
error_key: [2],
}
FSDPConfig(**cfg_dict)


@pytest.mark.gpu
@world_size(2)
def test_fsdp_shard(world_size: int):
2 changes: 2 additions & 0 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -1333,6 +1333,8 @@ def test_fsdp_monolith_resumption(
resume_file = os.path.join(save_folder, 'first', resume_file)
model_init_device = [model_1_init_device, model_2_init_device][dist.get_global_rank()]
fsdp_config_dict = dataclasses.asdict(fsdp_config)
# Since device_mesh being passed to FSDPConfig is deprecated, remove it.
fsdp_config_dict.pop('_device_mesh', None)
fsdp_config_dict['load_monolith_rank0_only'] = True
fsdp_config = FSDPConfig(**fsdp_config_dict)

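The added pop is needed because dataclasses.asdict includes the private _device_mesh field, which FSDPConfig's __init__ now rejects. A minimal sketch of the same round-trip (the helper name is illustrative, not from the commit):

import dataclasses

from composer.utils.parallelism import FSDPConfig

def clone_fsdp_config(cfg: FSDPConfig, **overrides) -> FSDPConfig:
    """Re-create an FSDPConfig, dropping the private field that asdict() emits."""
    cfg_dict = dataclasses.asdict(cfg)
    cfg_dict.pop('_device_mesh', None)  # rejected by FSDPConfig.__init__
    cfg_dict.update(overrides)
    return FSDPConfig(**cfg_dict)

# e.g. clone_fsdp_config(fsdp_config, load_monolith_rank0_only=True)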