
Commit

Update function doc strings (#7757)
JackCaoG committed Aug 2, 2024
1 parent 225ec7b commit d29b761
Showing 5 changed files with 72 additions and 33 deletions.
3 changes: 1 addition & 2 deletions docs/source/index.rst
@@ -75,8 +75,7 @@ distributed
----------------------------------

.. automodule:: torch_xla.distributed.parallel_loader
.. autoclass:: ParallelLoader
:members: per_device_loader
.. autoclass:: MpDeviceLoader

.. automodule:: torch_xla.distributed.xla_multiprocessing
.. autofunction:: spawn
41 changes: 33 additions & 8 deletions torch_xla/core/xla_model.py
@@ -1017,13 +1017,13 @@ def mark_step(wait=False, reset_scope=True):
torch_xla._XLAC._set_all_reduce_token(devctx.device, None)


# TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
# dumped as outputs. Needs further investigation.
def get_stablehlo(tensors=None) -> str:
"""Get StableHLO for the computation graph in string format.
If `tensors` is not empty, the graph with `tensors` as outputs will be dumped.
If `tensors` is empty, the whole computation graph will be dumped.
TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
dumped as outputs. Needs further investigation.
For an inference graph, it is recommended to pass the model outputs to `tensors`.
For a training graph, it is not straightforward to identify the "outputs", so using empty `tensors` is recommended.
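A minimal usage sketch of the inference-graph recommendation above (the Linear model and shapes are only illustrative; an XLA device is assumed to be available):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(8, 2).to(device)
output = model(torch.randn(4, 8).to(device))

# Passing the model outputs exports only the inference graph that produces them.
print(xm.get_stablehlo([output]))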
@@ -1043,13 +1043,13 @@ def get_stablehlo(tensors=None) -> str:
False).decode('utf-8')


# TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
# dumped as outputs. Needs further investigation.
def get_stablehlo_bytecode(tensors=None) -> bytes:
"""Get StableHLO for the computation graph in bytecode format.
If `tensors` is not empty, the graph with `tensors` as outputs will be dumped.
If `tensors` is empty, the whole computation graph will be dumped.
TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
dumped as outputs. Needs further investigation.
For an inference graph, it is recommended to pass the model outputs to `tensors`.
For a training graph, it is not straightforward to identify the "outputs", so using empty `tensors` is recommended.
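Since the bytecode form is meant for serialization, here is a sketch of writing it to disk (the file path and model are illustrative):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(8, 2).to(device)
output = model(torch.randn(4, 8).to(device))

# Serialize the inference graph as StableHLO bytecode.
with open('/tmp/linear.stablehlo.bc', 'wb') as f:
    f.write(xm.get_stablehlo_bytecode([output]))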
@@ -1154,7 +1154,7 @@ def optimizer_step(optimizer,
optimizer_args={},
groups=None,
pin_layout=True):
"""Run the provided optimizer step and issue the XLA device step computation.
"""Run the provided optimizer step and sync gradidents across all devices.
Args:
optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance whose
@@ -1177,6 +1177,11 @@
Returns:
The same value returned by the `optimizer.step()` call.
Example:
>>> import torch_xla.core.xla_model as xm
>>> xm.optimizer_step(self.optimizer)
"""
reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
loss = optimizer.step(**optimizer_args)
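A slightly fuller sketch of a single training step than the docstring example above (the model, data, and learning rate are illustrative; an XLA device is assumed):

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(8, 2).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

data = torch.randn(16, 8).to(device)
target = torch.randn(16, 2).to(device)

optimizer.zero_grad()
loss = loss_fn(model(data), target)
loss.backward()
xm.optimizer_step(optimizer)  # all-reduce gradients across replicas, then optimizer.step()
xm.mark_step()                # materialize the pending lazy computation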
@@ -1210,9 +1215,13 @@ def save(data, file_or_path, master_only=True, global_master=False):
controls whether every host's master (if ``global_master`` is ``False``)
saves the content, or only the global master (ordinal 0).
Default: False
sync (bool, optional): Whether to synchronize all replicas after saving
tensors. If True, all replicas must call `xm.save` or the main process
will hang.
Example:
>>> import torch_xla.core.xla_model as xm
>>> xm.wait_device_ops() # wait for all pending operations to finish.
>>> xm.save(obj_to_save, path_to_save)
>>> xm.rendezvous('torch_xla.core.xla_model.save') # multi process context only
"""
should_write_data = not master_only or is_master_ordinal(
local=not global_master)
@@ -1328,6 +1337,11 @@ def rendezvous(tag, payload=b'', replicas=[]):
Returns:
The payloads exchanged by all the other cores, with the payload of core
ordinal `i` at position `i` in the returned tuple.
Example:
>>> import torch_xla.core.xla_model as xm
>>> xm.rendezvous('example')
"""
return xla_rendezvous(payload, replicas or None, tag=tag)
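A sketch of the payload exchange described above; it must run inside each participating process (for example, one process per device spawned via xmp.spawn):

import torch_xla.core.xla_model as xm

# Each process contributes its ordinal; the call blocks until every
# participating process reaches the same tag.
payload = str(xm.get_ordinal()).encode()
replies = xm.rendezvous('ordinal_exchange', payload)
# replies[i] holds the payload sent by the process with ordinal i.
print([r.decode() for r in replies])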

@@ -1373,6 +1387,12 @@ def mesh_reduce(tag, data, reduce_fn):
Returns:
The reduced value.
Example:
>>> import torch_xla.core.xla_model as xm
>>> import numpy as np
>>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
"""
cpu_data = _maybe_convert_to_cpu(data)
bio = io.BytesIO()
@@ -1449,6 +1469,11 @@ def get_memory_info(device: Optional[torch.device] = None) -> MemoryInfo:
Returns:
MemoryInfo dict with memory usage for the given device.
Example:
>>> xm.get_memory_info()
{'bytes_used': 290816, 'bytes_limit': 34088157184}
"""
if device == None:
device = xla_device()
9 changes: 8 additions & 1 deletion torch_xla/distributed/parallel_loader.py
@@ -186,13 +186,20 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
class MpDeviceLoader(object):
"""Wraps an existing PyTorch DataLoader with background data upload.
This class should only be used with multi-processing data parallelism.
This class should only be used with multi-processing data parallelism. It will wrap
the dataloader passed in with ParallelLoader and return the per_device_loader for the
current device.
Args:
loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
wrapped.
device (`torch.device`...): The device where the data has to be sent.
kwargs: Named arguments for the `ParallelLoader` constructor.
Example:
>>> device = torch_xla.device()
>>> train_device_loader = MpDeviceLoader(train_loader, device)
"""

def __init__(self, loader, device, **kwargs):
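Building on the docstring example above, a minimal end-to-end sketch of iterating the wrapped loader (the synthetic dataset and batch size are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

device = xm.xla_device()
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 2))
train_loader = DataLoader(dataset, batch_size=16)

# Batches arrive already placed on the XLA device, with host-to-device
# upload overlapped with computation by background threads.
train_device_loader = MpDeviceLoader(train_loader, device)
for data, target in train_device_loader:
    pass  # forward/backward/optimizer step as usual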
27 changes: 11 additions & 16 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -41,7 +41,7 @@ class Mesh:
[4, 5],
[6, 7]])
>>> mesh.shape()
>>> OrderedDict([('x', 4), ('y', 2)])
OrderedDict([('x', 4), ('y', 2)])
"""

device_ids: np.ndarray
@@ -176,9 +176,9 @@ def get_1d_mesh(axis_name: Optional[str] = None) -> Mesh:
>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> print(mesh.mesh_shape)
>>> (4,)
(4,)
>>> print(mesh.axis_names)
>>> ('data',)
('data',)
"""
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices,)
@@ -579,21 +579,18 @@ def mark_sharding(
"""
Annotates the tensor provided with XLA partition spec. Internally,
it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
Args:
t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.
mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.
partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or
`None`. Each index is an int, str if the mesh axis is named, or tuple of int or str.
This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
When a tuple is specified, the corresponding input tensor axis will be sharded along all
logical axes in the tuple. Note that the order the mesh axes are specified in the tuple
will impact the resulting sharding.
For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
>> input = torch.randn(8, 10)
>> mesh_shape = (4, 2)
>> partition_spec = (0, None)
`None`. Each index is an int, str if the mesh axis is named, or tuple of int or str.
This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
When a tuple is specified, the corresponding input tensor axis will be sharded along all
logical axes in the tuple. Note that the order the mesh axes are specified in the tuple
will impact the resulting sharding.
dynamo_custom_op (bool): if set to True, it calls the dynamo custom op variant of mark_sharding
to make itself recognizable and traceable by dynamo.
@@ -606,12 +603,10 @@
>>> num_devices = xr.global_runtime_device_count()
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> # 4-way data parallel
>>> input = torch.randn(8, 32).to(xm.xla_device())
>>> xs.mark_sharding(input, mesh, (0, None))
>>> # 2-way model parallel
>>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel
>>> linear = nn.Linear(32, 10).to(xm.xla_device())
>>> xs.mark_sharding(linear.weight, mesh, (None, 1))
>>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
"""
num_devices = xr.global_runtime_device_count()
assert num_devices > 0, "This requires XLA supported device(s)."
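A sketch of the tuple-of-mesh-axes case described in the partition_spec notes above, sharding one tensor dimension across both mesh axes (the tensor shape is illustrative, and an even device count of at least two is assumed):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices // 2, 2), ('x', 'y'))

# Dimension 0 is sharded across both the 'x' and 'y' mesh axes;
# dimension 1 is replicated (None).
t = torch.randn(16, 32).to(xm.xla_device())
xs.mark_sharding(t, mesh, (('x', 'y'), None))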
25 changes: 19 additions & 6 deletions torch_xla/runtime.py
@@ -79,6 +79,11 @@ def _maybe_select_default_device():

def device_type() -> Optional[str]:
"""Returns the current PjRt device type.
Selects a default device if none has been configured.
Returns:
A string representation of the device.
"""
pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
return pjrt_device.split('_')[0] if pjrt_device else pjrt_device
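A quick usage sketch (the exact output depends on the configured runtime):

import torch_xla.runtime as xr

# Typically returns 'TPU', 'CUDA', or 'CPU'.
print(xr.device_type())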
@@ -212,11 +217,18 @@ def addressable_runtime_device_count() -> int:
return torch_xla._XLAC._xla_num_runtime_devices()


# API to enable SPMD mode. This is a recommended way to enable SPMD.
# This forces SPMD mode if some tensors are already initialized on non-SPMD
# devices. This means that those tensors would be replicated across the devices.
# TODO(yeounoh) introduce SPMD configuration.
def use_spmd(auto: Optional[bool] = False):
"""API to enable SPMD mode. This is a recommended way to enable SPMD.
This forces SPMD mode if some tensors are already initialized on non-SPMD
devices. This means that those tensors would be replicated across the devices.
Args:
auto (bool): Whether to enable auto-sharding. Read
https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding
for more details.
"""
if os.environ.get("XLA_USE_SPMD") is not None:
warnings.warn("XLA_USE_SPMD is being deprecated. "
"Use torch_xla.runtime.use_spmd() "
@@ -249,7 +261,8 @@ def get_master_ip() -> str:
"""Retrieve the master worker IP for the runtime. This calls into
backend-specific discovery APIs.
Returns master worker's IP address as a string."""
Returns:
master worker's IP address as a string."""
if device_type() == 'TPU':
return tpu.discover_master_worker_ip()
raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
@@ -260,8 +273,8 @@ def initialize_cache(path: str, readonly: bool = False):
before any computations have been performed.
Args:
path: The path at which to store the persistent cache.
readonly: Whether or not this worker should have write access to the cache.
path (str): The path at which to store the persistent cache.
readonly (bool): Whether or not this worker should have write access to the cache.
"""
assert not torch_xla._XLAC._xla_computation_cache_is_initialized(
), "Computation cache has already been initialized"
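A usage sketch for the persistent cache (the cache path is illustrative); as the docstring notes, it must be called before any computation runs:

import torch_xla.runtime as xr

# Call once per process, right after startup and before any XLA execution.
xr.initialize_cache('/tmp/xla_compile_cache', readonly=False)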
