
Introduce multi-node SPMD initialization for Neuron
rpsilva-aws committed Sep 30, 2024
1 parent d0ea5cc commit eaef0b8
Showing 3 changed files with 170 additions and 45 deletions.
5 changes: 1 addition & 4 deletions torch_xla/__init__.py
@@ -137,10 +137,7 @@ def _aws_ec2_inf_trn_init():
# Basic initializations if torch-neuronx is not available
from ._internal import neuron
if os.path.basename(sys.argv[0]) != 'neuron_parallel_compile':
neuron.initialize()
else:
xla.init()
# Found libneuronxla
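For context, the resulting import-time flow on a Neuron instance can be sketched as follows (a minimal illustration of the commit's intent, not part of the diff itself; assumes a torchelastic launch):

import torch_xla                 # _aws_ec2_inf_trn_init() -> neuron.initialize()
import torch_xla.runtime as xr

xr.use_spmd()                    # on NEURON, re-runs initialize(): the singleton
                                 # Initializer.reset() restores the env vars saved by
                                 # the first pass and reconfigures them for SPMD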
202 changes: 161 additions & 41 deletions torch_xla/_internal/neuron.py
@@ -2,6 +2,7 @@
import logging

from torch_xla.experimental import plugins
from torch_xla import runtime as xr

import sys
import torch.distributed as dist
@@ -10,6 +11,15 @@

logging.basicConfig()
logger = logging.getLogger(__name__)
# Singleton initializer to ensure that initialization is only performed once.
_initializer = None


def initialize():
global _initializer
if not _initializer:
_initializer = Initializer()
_initializer.reset()


# Set root communication address/port
@@ -37,51 +47,161 @@ def set_rt_root_comm_id():
os.environ['NEURON_RT_ROOT_COMM_ID'] = '{}:{}'.format(root_addr, root_port)


def _set_envvar_defaults():
os.environ.setdefault('ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB', '50')


class Initializer:
"""
Manages the Neuron initialization for torch. It guarantees that the
environment is correctly configured for both SPMD and non-SPMD use cases.
Note that when SPMD is enabled, initialization requires reconfiguring the
environment, since enabling SPMD happens after the default import-time
initialization.
"""

# Whether the PJRT environment has already been configured.
configured_pjrt_env = False
# The previous state of the PJRT environment before the latest
# configuration.
previous_pjrt_env_vars = {}

def __init__(self):
import libneuronxla
libneuronxla.configure_environment()
_set_envvar_defaults()
# Environment-agnostic PJRT configurations that only need to be set once.
self._initialize_pjrt_ranks()

def reset(self):
if self.configured_pjrt_env:
self.__clear_previous_pjrt_env_vars()
assert not (self.previous_pjrt_env_vars or self.configured_pjrt_env)
self._configure_pjrt_environment()
self.configured_pjrt_env = True

def _initialize_pjrt_ranks(self):
"""
Initialize the PJRT specific ranks for torch.
"""
if 'RANK' not in os.environ:
logger.warning('RANK environment variable is not set, defaulting to 0.')
self.__set_envvar_defaulted_and_save('NEURON_PJRT_PROCESS_INDEX', 'RANK', '0')
if 'LOCAL_RANK' not in os.environ:
logger.warning(
'LOCAL_RANK environment variable is not set, defaulting to 0.')
self.__set_envvar_defaulted_and_save('PJRT_LOCAL_PROCESS_RANK', 'LOCAL_RANK', '0')

def _configure_pjrt_environment(self):
"""
Setting all necessary PJRT default environment variables. There are currently two schemes:
- __configure_non_spmd_environment, for the non-SPMD setup.
- __configure_spmd_environment, for the SPMD setup.
"""
def __configure_non_spmd_environment():
"""
Setting all necessary PJRT environment variables for non-SPMD:
1) NEURON_PJRT_PROCESSES_NUM_DEVICES: `X,Y,Z` denotes three worker processes with X, Y and Z
   addressable devices, respectively.
2) NEURON_PJRT_WORLD_SIZE: This denotes the total number of worker processes, each with one
   addressable device. For instance, '8' will expand to '1,1,1,1,1,1,1,1'.
3) NEURON_RT_VISIBLE_CORES: The specified visible cores are unwrapped, and each local rank is
   assigned the core at its corresponding index.
4) Default behavior:
   * NEURON_PJRT_WORLD_SIZE is set from WORLD_SIZE, denoting the global number of participating
     devices.
   * PJRT_LOCAL_PROCESS_COUNT is set from LOCAL_WORLD_SIZE, denoting the number of local
     participating processes.
"""
# NEURON_PJRT_PROCESSES_NUM_DEVICES is a list of core counts and is too long for very large clusters,
# so use NEURON_PJRT_WORLD_SIZE to pass world size and use core count of 1 per process in PJRT client.
if 'NEURON_PJRT_PROCESSES_NUM_DEVICES' not in os.environ and 'NEURON_PJRT_WORLD_SIZE' not in os.environ:
if 'WORLD_SIZE' not in os.environ:
logger.warning(
'WORLD_SIZE environment variable not set, defaulting to 1.')
self.__set_envvar_defaulted_and_save("NEURON_PJRT_WORLD_SIZE", "WORLD_SIZE", "1")
if 'LOCAL_WORLD_SIZE' not in os.environ:
logger.warning(
'LOCAL_WORLD_SIZE environment variable not set, defaulting to 1.')
self.__set_envvar_defaulted_and_save("PJRT_LOCAL_PROCESS_COUNT", "LOCAL_WORLD_SIZE", "1")
visible_cores = get_visible_cores_list()
self.__set_envvar_defaulted_and_save('NEURON_RT_VISIBLE_CORES',
'LOCAL_RANK', '0' if not visible_cores else visible_cores)

def __configure_spmd_environment():
"""
Setting all necessary PJRT environment variables for SPMD:
1) NEURON_PJRT_PROCESSES_NUM_DEVICES: `X,Y,Z` denotes X, Y and Z addressable devices for
   the single worker process on each of the three respective nodes.
2) Default behavior:
   * Single-node:
     Use a single worker process that owns all visible neuron cores:
     * NEURON_RT_VISIBLE_CORES, if specified;
     * otherwise, all available neuron cores on the instance.
   * Multi-node:
     No default support; requires 1).
"""
# In SPMD, 'WORLD_SIZE' represents the global number of participating nodes.
if 'WORLD_SIZE' not in os.environ:
logger.warning(
'WORLD_SIZE environment variable not set, defaulting to 1.')
if 'LOCAL_WORLD_SIZE' not in os.environ:
logger.warning(
'LOCAL_WORLD_SIZE environment variable not set, defaulting to 1.')
self.__set_envvar_defaulted_and_save('PJRT_LOCAL_PROCESS_COUNT', 'LOCAL_WORLD_SIZE', '1')

# 'NEURON_PJRT_PROCESSES_NUM_DEVICES' is required for multi-node support.
assert (os.environ.get('WORLD_SIZE', '1') == '1' or 'NEURON_PJRT_PROCESSES_NUM_DEVICES' in os.environ), (
'NEURON_PJRT_PROCESSES_NUM_DEVICES environment variable not set. This is required to enable '
'multi-node SPMD.')
if 'NEURON_RT_VISIBLE_CORES' in os.environ:
# In SPMD, we do not remap the visible cores based on the local worker
# rank; instead, we just unwrap the visible cores list if specified.
self.__set_envvar_defaulted_and_save('NEURON_RT_VISIBLE_CORES',
None, get_visible_cores_list())

from torch.distributed import is_torchelastic_launched
# If not launched via torchelastic, do not set the environment variables
# here. In this case, the environment variables are initialized in the
# default initializer with `initialize_env`.
if not is_torchelastic_launched():
return

if xr.is_spmd():
__configure_spmd_environment()
else:
__configure_non_spmd_environment()

def __clear_previous_pjrt_env_vars(self):
"""
Reset the environment variables for the PJRT environment to their former
state.
"""
assert self.configured_pjrt_env
logger.warning('Reinitializing the PJRT environment.')
if self.previous_pjrt_env_vars:
# Restore each variable to its pre-configuration state (pop the ones
# that were previously unset).
for key, previous_val in self.previous_pjrt_env_vars.items():
  if previous_val is None:
    os.environ.pop(key, None)
  else:
    os.environ[key] = previous_val
self.previous_pjrt_env_vars = {}
self.configured_pjrt_env = False

def __set_envvar_defaulted_and_save(self, key_to, key_from, default_value):
"""
This is used to set a default value for an environment variable if it
is not already set, and then save the original value of the environment
variable to track its state in case we require re-initializing the
environment.
"""
if callable(default_value):
default_value = default_value()
value = os.environ.get(key_from, default_value) if key_from else default_value
if key_to in os.environ and os.environ[key_to] != value:
  logger.debug(f"{key_to} environment variable is set, overriding to {value}.")
# Record the pre-existing value (None if unset) so reset() can restore it.
self.previous_pjrt_env_vars[key_to] = os.environ.get(key_to)
os.environ[key_to] = value
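# Illustrative trace (hypothetical values, not part of the commit): with
# RANK=3 exported by torchrun, the call in _initialize_pjrt_ranks,
#   self.__set_envvar_defaulted_and_save('NEURON_PJRT_PROCESS_INDEX', 'RANK', '0'),
# sets NEURON_PJRT_PROCESS_INDEX=3 and records its prior value (None if it
# was unset) so reset() can restore the environment before reconfiguring.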


def num_local_processes() -> int:
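To make the two schemes concrete, the helper below (hypothetical, not part of the commit) derives the per-process device counts described in the docstrings above from the environment variables each path sets:

import os

def pjrt_process_device_counts():
    """Hypothetical helper: one entry per worker process, each entry being
    that process's addressable device count."""
    explicit = os.environ.get('NEURON_PJRT_PROCESSES_NUM_DEVICES')
    if explicit:
        # SPMD multi-node (or an explicit non-SPMD list): "32,32" -> [32, 32]
        return [int(n) for n in explicit.split(',')]
    # Non-SPMD torchrun path: a world size of N expands to N single-device
    # processes, e.g. '8' -> [1, 1, 1, 1, 1, 1, 1, 1].
    world_size = int(os.environ.get('NEURON_PJRT_WORLD_SIZE', '1'))
    return [1] * world_size

# Example: two-node SPMD, one worker process per node, 32 cores each.
os.environ['NEURON_PJRT_PROCESSES_NUM_DEVICES'] = '32,32'
assert pjrt_process_device_counts() == [32, 32]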
8 changes: 8 additions & 0 deletions torch_xla/runtime.py
@@ -253,6 +253,14 @@ def use_spmd(auto: Optional[bool] = False):
torch_xla._XLAC._xla_set_auto_sharding()
os.environ["XLA_AUTO_SPMD"] = "1"

if runtime.device_type() == 'NEURON':
# In case of Neuron, reset the initialization environment to accommodate SPMD.
try:
from torch_neuronx.initialization import initialize
except ImportError:
from ._internal.neuron import initialize
initialize()


def is_spmd():
"""Returns if SPMD is set for execution."""
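Putting the runtime hook together, a hedged usage sketch (the mesh shape and axis name are illustrative; Mesh is torch_xla's SPMD mesh abstraction):

import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

# On a NEURON device, use_spmd() now also re-runs the Neuron initializer so
# the PJRT environment is reconfigured for SPMD (torch_neuronx's initialize()
# is preferred when importable, with the in-tree fallback added above).
xr.use_spmd()

num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (num_devices,), ('data',))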
