From 3cd46778cc170cf744d8ed1a7eac095e07064f5b Mon Sep 17 00:00:00 2001 From: jeffhataws Date: Wed, 25 Sep 2024 16:16:22 -0700 Subject: [PATCH] Fix nprocs description for xmp.spawn; error if nprocs != 1 or None (#7971) --- torch_xla/_internal/pjrt.py | 8 ++++++-- torch_xla/distributed/xla_multiprocessing.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/torch_xla/_internal/pjrt.py b/torch_xla/_internal/pjrt.py index 4f889516407..71d9ddc4865 100644 --- a/torch_xla/_internal/pjrt.py +++ b/torch_xla/_internal/pjrt.py @@ -194,7 +194,9 @@ def spawn(fn: Callable, Args: fn: Callable that takes the process index as the first argument. nprocs (int): The number of processes/devices for the replication. At the - moment, if specified, can be either 1 or the maximum number of devices. + moment, if specified, can be either 1 or None (which would automatically + converted to the maximum number of devices). Other numbers would result + in ValueError. args: args to pass to `fn` start_method: The Python `multiprocessing` process creation method. Default: `spawn` @@ -204,7 +206,9 @@ def spawn(fn: Callable, if nprocs == 1: return _run_singleprocess(spawn_fn) elif nprocs is not None: - logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs) + raise ValueError( + 'Unsupported nprocs (%d). Please use the environment variable for the hardware you are using (X_NUM_DEVICES where X is CPU, GPU, TPU, NEURONCORE, etc).' + % nprocs) run_multiprocess(spawn_fn, start_method=start_method) diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index 4a59f30d169..9e7dbc879a2 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -19,7 +19,9 @@ def spawn(fn, args (tuple): The arguments for `fn`. Default: Empty tuple nprocs (int): The number of processes/devices for the replication. At the - moment, if specified, can be either 1 or the maximum number of devices. + moment, if specified, can be either 1 or None (which would automatically + converted to the maximum number of devices). Other numbers would result + in ValueError. join (bool): Whether the call should block waiting for the completion of the processes which have being spawned. Default: True