Introduce torch_xla.launch() (#7648)
zpcore committed Jul 12, 2024
1 parent 75b10b7 commit 5b8e8e0
Showing 3 changed files with 40 additions and 10 deletions.
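At a high level, torch_xla.launch() replaces the direct xmp.spawn() calls in the examples and also covers the torchrun (torchelastic) case internally, so scripts no longer need to branch on how they were started. A minimal migration sketch, assuming a trivial worker function _mp_fn(index) along the lines of the files below:

import torch_xla


def _mp_fn(index):
  # index is the process ordinal supplied by the launcher
  print(f'hello from process {index}')


if __name__ == '__main__':
  # before this commit: xmp.spawn(_mp_fn, args=())
  # after: launch() decides between spawning processes and running inline
  torch_xla.launch(_mp_fn)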
10 changes: 7 additions & 3 deletions examples/data_parallel/train_resnet_ddp.py
@@ -8,13 +8,15 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla
import torch_xla.core.xla_model as xm


class TrainResNetDDP(TrainResNetBase):

def __init__(self):
super().__init__()
dist.init_process_group('xla', init_method='xla://')
super().__init__()
self.model = DDP(
self.model, gradient_as_bucket_view=True, broadcast_buffers=False)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
@@ -26,5 +28,7 @@ def _mp_fn(index):


if __name__ == '__main__':
print('consider using train_resnet_spmd_data_parallel.py instead to get better performance')
xmp.spawn(_mp_fn, args=())
print(
'consider using train_resnet_spmd_data_parallel.py instead to get better performance'
)
torch_xla.launch(_mp_fn)
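The worker that launch() starts for this example still follows the usual XLA DDP recipe visible in the hunk above: initialize the 'xla' process group, move the model to the XLA device, then wrap it in DDP. A condensed, self-contained sketch under those assumptions (the tiny Linear model and the single training step are illustrative stand-ins for the ResNet base class):

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla
import torch_xla.core.xla_model as xm


def _mp_fn(index):
  # the 'xla' backend with init_method='xla://' reads rank and world size
  # from the XLA runtime, so no manual env-var plumbing is needed here
  dist.init_process_group('xla', init_method='xla://')
  device = torch_xla.device()
  model = torch.nn.Linear(8, 8).to(device)  # stand-in for the real model
  ddp_model = DDP(model, gradient_as_bucket_view=True, broadcast_buffers=False)
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.1, weight_decay=1e-4)

  optimizer.zero_grad()
  loss = ddp_model(torch.randn(4, 8).to(device)).sum()
  loss.backward()
  optimizer.step()
  xm.mark_step()  # cut the graph and execute the step on the device


if __name__ == '__main__':
  torch_xla.launch(_mp_fn)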
9 changes: 3 additions & 6 deletions test/test_train_mp_imagenet.py
@@ -84,9 +84,7 @@
import torch_xla.distributed.parallel_loader as pl
import torch_xla.debug.profiler as xp
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import torch.distributed as dist
@@ -375,7 +373,6 @@ def _mp_fn(index, flags):


if __name__ == '__main__':
if dist.is_torchelastic_launched():
_mp_fn(xu.getenv_as(xenv.LOCAL_RANK, int), FLAGS)
else:
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
# if running with torchrun, nprocs argument will be omitted.
  debug_single_process = True if FLAGS.num_cores == 1 else False
  torch_xla.launch(
      _mp_fn, args=(FLAGS,), debug_single_process=debug_single_process)
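The explicit dist.is_torchelastic_launched() branch this script used to carry is now handled inside launch(), and debug_single_process maps to a single worker process, which keeps the whole run debuggable when --num_cores=1 is passed. A hedged sketch of that debugging path (the worker body and the flags object are illustrative, not part of the commit):

import types
import torch_xla


def _mp_fn(index, flags):
  # with debug_single_process=True there is exactly one process, so a
  # breakpoint placed here is hit in the process started from the shell
  print(f'rank {index}, batch_size={flags.batch_size}')


if __name__ == '__main__':
  flags = types.SimpleNamespace(batch_size=128)
  torch_xla.launch(_mp_fn, args=(flags,), debug_single_process=True)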
31 changes: 30 additions & 1 deletion torch_xla/torch_xla.py
@@ -1,9 +1,14 @@
import contextlib
from typing import List
from typing import Callable, List, Tuple

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_env_vars as xenv
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu


def device(index: int = None) -> torch.device:
@@ -80,3 +85,27 @@ def manual_seed(seed, device=None):
If missing the default device seed will be set.
"""
xm.set_rng_state(seed, device)


def launch(
fn: Callable,
args: Tuple = (),
start_method: str = 'spawn',
debug_single_process: bool = False,
):
""" Entry to launch multiprocess.
Raises:
NotImplementedError: SPMD is not supported yet.
"""
if xr.is_spmd():
# TODO(piz): SPMD is specified differently from mp. Skip for now.
raise NotImplementedError(
'launch function does not support SPMD at this time')

nprocs = 1 if debug_single_process else None

if dist.is_torchelastic_launched():
fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
else:
xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
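Summarizing the dispatch above: under torchrun the current process calls fn directly with its LOCAL_RANK, otherwise xmp.spawn() creates the worker processes, one per device, or exactly one when debug_single_process is set. A small hedged sketch of the argument forwarding (_worker and the learning-rate argument are illustrative):

import torch_xla
import torch_xla.runtime as xr


def _worker(index, lr):
  # launch() supplies the index; everything in args is appended after it
  print(f'worker {index} of {xr.world_size()} using lr={lr}')


if __name__ == '__main__':
  torch_xla.launch(_worker, args=(0.1,), start_method='spawn')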
