Commit 8be771c: relocate world_size()
zpcore committed Jul 17, 2024
1 parent 949777f
Showing 3 changed files with 74 additions and 40 deletions.
torch_xla/_internal/pjrt.py (2 changes: 1 addition & 1 deletion)
@@ -66,7 +66,7 @@ def _thread_fn(device: torch.device):
torch_xla._XLAC._xla_set_default_device(device)

# See Note [Dynamo WORLD_SIZE and ORDINAL].
xm._init_world_size_ordinal()
runtime._init_world_size_ordinal()

return fn()

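For reference, the hunk above only swaps which module owns the cache initializer: each spawned thread still binds its device first and then seeds the cached world size and ordinal before running user code. A minimal sketch of that per-thread setup follows; it is not the verbatim `_thread_fn`, and `user_fn` is a hypothetical stand-in for the workload.

import torch
import torch_xla
import torch_xla.runtime as runtime


def example_thread_fn(device: torch.device, user_fn):
  # Bind this thread to its XLA device before anything touches the runtime.
  torch_xla._XLAC._xla_set_default_device(device)

  # Seed runtime._WORLD_SIZE / runtime._ORDINAL once, so later calls to
  # runtime.world_size() and runtime.get_ordinal() return cached Python ints.
  # See Note [Dynamo WORLD_SIZE and ORDINAL].
  runtime._init_world_size_ordinal()

  return user_fn()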
torch_xla/core/xla_model.py (67 changes: 30 additions & 37 deletions)
@@ -20,6 +20,7 @@
import torch_xla.utils.closures as xc
import os
from torch_xla.experimental.deprecation import deprecated
from . import xla_model as this_module

_DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())

@@ -33,27 +34,22 @@
_DEVICE_CONTEXTS = dict()
_DEVICE_CONTEXTS_LOCK = threading.Lock()

# Note [Dynamo WORLD_SIZE and ORDINAL]
# Below is a workaround to cache the ordinal and world_size so that
# Dynamo won't do graph breaks when xm.pjrt_world_size() and xm.get_ordinal() are called.
_WORLD_SIZE = None
_ORDINAL = None

XLA_LIB = Library("xla", "DEF")

xrt_world_size = deprecated(torch_xla.core, torch_xla.runtime.world_size)
xrt_world_size = deprecated(this_module, torch_xla.runtime.world_size)
get_ordinal = deprecated(this_module, torch_xla.runtime.get_ordinal)

# def _init_world_size_ordinal():
# global _WORLD_SIZE, _ORDINAL

def _init_world_size_ordinal():
global _WORLD_SIZE, _ORDINAL
# # Dynamo doesn't support XRT or multithreaded runtime. See Note [V3-8 Threading]
# if not runtime.using_pjrt() or runtime.addressable_device_count() > 1:
# return

# Dynamo doesn't support XRT or multithreaded runtime. See Note [V3-8 Threading]
if not runtime.using_pjrt() or runtime.addressable_device_count() > 1:
return

if _WORLD_SIZE is None:
_WORLD_SIZE = runtime.world_size()
_ORDINAL = get_ordinal()
# if _WORLD_SIZE is None:
# _WORLD_SIZE = runtime.world_size()
# _ORDINAL = get_ordinal()


class DeviceContext(object):
@@ -118,24 +114,24 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
return kind_devices[:max_devices] if max_devices else kind_devices


def get_ordinal(defval=0):
"""Retrieves the replication ordinal of the current thread.
# def get_ordinal(defval=0):
# """Retrieves the replication ordinal of the current thread.

The ordinals range from 0 to `runtime.world_size()` minus 1.
# The ordinals range from 0 to `runtime.world_size()` minus 1.

Args:
defval (int, optional): The default value to be returned in case there is no
replication information available. Ignored for runtime.
Default: 0
# Args:
# defval (int, optional): The default value to be returned in case there is no
# replication information available. Ignored for runtime.
# Default: 0

Returns:
The replication ordinal of the current thread.
"""
global _ORDINAL
if _ORDINAL is not None:
return _ORDINAL
# Returns:
# The replication ordinal of the current thread.
# """
# global _ORDINAL
# if _ORDINAL is not None:
# return _ORDINAL

return runtime.global_ordinal()
# return runtime.global_ordinal()


def get_local_ordinal(defval=0):
@@ -166,7 +162,7 @@ def is_master_ordinal(local=True):
Returns:
A boolean indicating whether the current process is the master ordinal.
"""
ordinal = get_local_ordinal() if local else get_ordinal()
ordinal = get_local_ordinal() if local else runtime.get_ordinal()
return ordinal == 0


@@ -465,10 +461,7 @@ def all_reduce(reduce_type, inputs, scale=1.0, groups=None, pin_layout=True):
groups = groups or []

# No-op if there is only one device
global _WORLD_SIZE
if _WORLD_SIZE is None:
_WORLD_SIZE = runtime.world_size()
if _WORLD_SIZE == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
if runtime.world_size() == 1 and not xu.getenv_as('XLA_ALWAYS_ALLREDUCE',
bool, False):
if isinstance(inputs, torch.Tensor):
return inputs.clone()
@@ -520,7 +513,7 @@ def _all_gather_using_all_reduce(value, dim=0, groups=None, pin_layout=True):
dim = value.dim() + dim
size = value.size(dim)
padding = [0] * (2 * value.dim())
ordinal = get_ordinal()
ordinal = runtime.get_ordinal()
if groups is None:
left, right = ordinal, runtime.world_size() - 1 - ordinal
else:
@@ -813,7 +806,7 @@ def collective_broadcast(tensors: List[torch.Tensor],
# so each replica must have the same multiply op with the same parameters.
for tensor in tensors:
scale = torch.tensor(
1 if get_ordinal() == root_ordinal else 0, dtype=tensor.dtype)
1 if runtime.get_ordinal() == root_ordinal else 0, dtype=tensor.dtype)
# Transfer scale tensor as device data instead of constant 1 or 0.
xscale = send_cpu_data_to_device(scale, tensor.device)
tensor.mul_(xscale[0])
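The collective_broadcast hunk above keeps the broadcast expressed as a multiply-then-all-reduce so every replica traces the same HLO; only the ordinal check now goes through runtime.get_ordinal(). Below is a self-contained sketch of that trick using public torch_xla APIs, not the verbatim function body.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as runtime


def broadcast_like(tensor: torch.Tensor, root_ordinal: int = 0) -> torch.Tensor:
  # Scale by 1 on the root replica and 0 everywhere else; the multiply op is
  # identical on every replica, only the transferred scalar differs.
  scale = torch.tensor(
      1 if runtime.get_ordinal() == root_ordinal else 0, dtype=tensor.dtype)
  xscale = xm.send_cpu_data_to_device(scale, tensor.device)
  zeroed = tensor * xscale[0]
  # Summing across replicas leaves every replica holding the root's value.
  return xm.all_reduce(xm.REDUCE_SUM, zeroed)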
@@ -1386,7 +1379,7 @@ def do_on_ordinals(target, data=(), ordinals=(0,)):
In the ordinals that ran the `target` function, the function return value,
otherwise `None`.
"""
running = get_ordinal() in ordinals
running = runtime.get_ordinal() in ordinals
cpu_data = _maybe_convert_to_cpu(data, convert=running)
if running:
result = target(*cpu_data)
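In xla_model.py the function bodies are removed (or commented out above), and the public names survive only as deprecation aliases created by deprecated(this_module, ...). The wrapper below is an illustrative sketch of what such an alias typically does; it is not the actual torch_xla.experimental.deprecation implementation, and deprecated_alias is a hypothetical name.

import functools
import warnings


def deprecated_alias(old_name: str, new_fn):
  """Keep an old public name working while pointing callers at its replacement."""

  @functools.wraps(new_fn)
  def wrapper(*args, **kwargs):
    warnings.warn(
        f"{old_name} is deprecated; use torch_xla.runtime.{new_fn.__name__} instead.",
        DeprecationWarning,
        stacklevel=2)
    return new_fn(*args, **kwargs)

  return wrapper


# Usage mirroring the diff:
#   xrt_world_size = deprecated_alias("xm.xrt_world_size", runtime.world_size)
#   get_ordinal = deprecated_alias("xm.get_ordinal", runtime.get_ordinal)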
torch_xla/runtime.py (45 changes: 43 additions & 2 deletions)
@@ -16,6 +16,24 @@
R = TypeVar('R')
FN = TypeVar('FN')

# Note [Dynamo WORLD_SIZE and ORDINAL]
# Below is a workaround to cache the ordinal and world_size so that
# Dynamo won't do graph breaks when xm.xrt_world_size() and xm.get_ordinal() are called.
_WORLD_SIZE = None
_ORDINAL = None


def _init_world_size_ordinal():
global _WORLD_SIZE, _ORDINAL

# Dynamo doesn't support XRT or multithreaded runtime. See Note [V3-8 Threading]
if not using_pjrt() or addressable_device_count() > 1:
return

if _WORLD_SIZE is None:
_WORLD_SIZE = world_size()
_ORDINAL = get_ordinal()


def set_device_type(pjrt_device: str) -> None:
"""Sets the current PjRt device type.
@@ -147,9 +165,32 @@ def global_device_count() -> int:
@requires_pjrt
def world_size() -> int:
"""Returns the total number of processes participating in the job."""
global _WORLD_SIZE
if _WORLD_SIZE is not None:
return _WORLD_SIZE
if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
return 1
return global_device_count()
_WORLD_SIZE = 1
else:
_WORLD_SIZE = global_device_count()
return _WORLD_SIZE

@requires_pjrt
def get_ordinal(defval=0):
"""Retrieves the replication ordinal of the current thread.
The ordinals range from 0 to `runtime.world_size()` minus 1.
Args:
defval (int, optional): The default value to be returned in case there is no
replication information available. Ignored for runtime.
Default: 0
Returns:
The replication ordinal of the current thread.
"""
global _ORDINAL
if _ORDINAL is not None:
return _ORDINAL
# Fall back to the uncached query (mirrors the removed xm.get_ordinal()).
return global_ordinal()


@requires_pjrt
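The net effect of the runtime.py hunk above is ordinary memoization: the first call to world_size() or get_ordinal() stores a plain Python int at module scope, and every later call returns that constant, which is what lets Dynamo treat the value as static instead of breaking the graph on an opaque runtime query. A minimal, library-free sketch of the pattern (names here are illustrative, not torch_xla APIs):

from typing import Optional

_CACHED_WORLD_SIZE: Optional[int] = None


def _query_world_size() -> int:
  # Hypothetical stand-in for the real query, e.g. counting replication devices.
  return 1


def cached_world_size() -> int:
  global _CACHED_WORLD_SIZE
  if _CACHED_WORLD_SIZE is None:
    _CACHED_WORLD_SIZE = _query_world_size()
  return _CACHED_WORLD_SIZE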
