diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4652a2115b8..edd6bcc5372 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -75,8 +75,7 @@ distributed
 ----------------------------------
 
 .. automodule:: torch_xla.distributed.parallel_loader
-.. autoclass:: ParallelLoader
-  :members: per_device_loader
+.. autoclass:: MpDeviceLoader
 
 .. automodule:: torch_xla.distributed.xla_multiprocessing
 .. autofunction:: spawn
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index e64e78a1216..d3c243335f3 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1017,13 +1017,13 @@ def mark_step(wait=False, reset_scope=True):
     torch_xla._XLAC._set_all_reduce_token(devctx.device, None)
 
 
+# TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
+# dumped as outputs. Needs further investigation.
 def get_stablehlo(tensors=None) -> str:
   """Get StableHLO for the computation graph in string format.
 
   If `tensors` is not empty, the graph with `tensors` as outputs will be dump.
   If `tensors` is empty, the whole computation graph will be dump.
-  TODO(lsy323): When `tensors` is empty, the some intermediate tensors will also be
-  dump as outputs. Need further investigation.
   For inference graph, it is recommended to pass the model outputs to `tensors`.
   For training graph, it is not straightforward to identify the "outputs". Using empty `tensors` is recommended.
@@ -1043,13 +1043,13 @@ def get_stablehlo(tensors=None) -> str:
                                          False).decode('utf-8')
 
 
+# TODO(lsy323): When `tensors` is empty, some intermediate tensors will also be
+# dumped as outputs. Needs further investigation.
 def get_stablehlo_bytecode(tensors=None) -> bytes:
   """Get StableHLO for the computation graph in bytecode format.
 
   If `tensors` is not empty, the graph with `tensors` as outputs will be dump.
   If `tensors` is empty, the whole computation graph will be dump.
-  TODO(lsy323): When `tensors` is empty, the some intermediate tensors will also be
-  dump as outputs. Need further investigation.
   For inference graph, it is recommended to pass the model outputs to `tensors`.
   For training graph, it is not straightforward to identify the "outputs". Using empty `tensors` is recommended.
@@ -1154,7 +1154,7 @@ def optimizer_step(optimizer,
                    optimizer_args={},
                    groups=None,
                    pin_layout=True):
-  """Run the provided optimizer step and issue the XLA device step computation.
+  """Run the provided optimizer step and sync gradients across all devices.
 
   Args:
     optimizer (:class:`torch.Optimizer`): The `torch.Optimizer` instance whose
@@ -1177,6 +1177,11 @@ def optimizer_step(optimizer,
 
   Returns:
     The same value returned by the `optimizer.step()` call.
+
+  Example:
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> xm.optimizer_step(optimizer) # assumes an existing `optimizer`
   """
   reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
   loss = optimizer.step(**optimizer_args)
@@ -1210,9 +1215,13 @@ def save(data, file_or_path, master_only=True, global_master=False):
       controls whether every host's master (if ``global_master`` is
       ``False``) saves the content, or only the global master (ordinal 0).
       Default: False
-    sync (bool, optional): Whether to synchronize all replicas after saving
-      tensors. If True, all replicas must call `xm.save` or the main process
-      will hang.
+
+  Example:
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> xm.wait_device_ops() # wait for all pending operations to finish.
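+    >>> # `obj_to_save` and `path_to_save` below are illustrative placeholders,
+    >>> # e.g. a module state_dict and a checkpoint path:
+    >>> obj_to_save = model.state_dict() # assumes an existing `model`
+    >>> path_to_save = 'checkpoint.pt'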
+    >>> xm.save(obj_to_save, path_to_save)
+    >>> xm.rendezvous('torch_xla.core.xla_model.save') # multi-process context only
   """
   should_write_data = not master_only or is_master_ordinal(
       local=not global_master)
 
@@ -1328,6 +1337,11 @@ def rendezvous(tag, payload=b'', replicas=[]):
 
   Returns:
     The payloads exchanged by all the other cores, with the payload of core
     ordinal `i` at position `i` in the returned tuple.
+
+  Example:
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> xm.rendezvous('example')
   """
   return xla_rendezvous(payload, replicas or None, tag=tag)
 
@@ -1373,6 +1387,12 @@ def mesh_reduce(tag, data, reduce_fn):
 
   Returns:
     The reduced value.
+
+  Example:
+
+    >>> import torch_xla.core.xla_model as xm
+    >>> import numpy as np
+    >>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) # mean of each replica's local `accuracy`
   """
   cpu_data = _maybe_convert_to_cpu(data)
   bio = io.BytesIO()
 
@@ -1449,6 +1469,11 @@ def get_memory_info(device: Optional[torch.device] = None) -> MemoryInfo:
 
   Returns:
     MemoryInfo dict with memory usage for the given device.
+
+  Example:
+
+    >>> xm.get_memory_info()
+    {'bytes_used': 290816, 'bytes_limit': 34088157184}
   """
   if device == None:
     device = xla_device()
diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py
index 8af7196e95c..94dfcbc3b2e 100644
--- a/torch_xla/distributed/parallel_loader.py
+++ b/torch_xla/distributed/parallel_loader.py
@@ -186,13 +186,20 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
 
 class MpDeviceLoader(object):
   """Wraps an existing PyTorch DataLoader with background data upload.
 
-  This class should only be using with multi-processing data parallelism.
+  This class should only be used with multi-processing data parallelism. It wraps
+  the dataloader passed in with `ParallelLoader` and returns the `per_device_loader`
+  for the current device.
 
   Args:
     loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
      wrapped.
    device (`torch.device`...): The device where the data has to be sent.
    kwargs: Named arguments for the `ParallelLoader` constructor.
+
+  Example:
+
+    >>> device = torch_xla.device()
+    >>> train_device_loader = MpDeviceLoader(train_loader, device)
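+    >>> # A minimal consumption sketch (assumes an existing `model`); batches
+    >>> # are uploaded to `device` in the background while the host iterates:
+    >>> for data, target in train_device_loader:
+    ...   output = model(data)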
   """
 
   def __init__(self, loader, device, **kwargs):
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 39edf73b156..a7b7ec758bd 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -41,7 +41,7 @@ class Mesh:
         [4, 5],
         [6, 7]])
   >>> mesh.shape()
-  >>> OrderedDict([('x', 4), ('y', 2)])
+  OrderedDict([('x', 4), ('y', 2)])
   """
 
   device_ids: np.ndarray
@@ -176,9 +176,9 @@ def get_1d_mesh(axis_name: Optional[str] = None) -> Mesh:
     >>> import torch_xla.distributed.spmd as xs
     >>> mesh = xs.get_1d_mesh("data")
     >>> print(mesh.mesh_shape)
-    >>> (4,)
+    (4,)
     >>> print(mesh.axis_names)
-    >>> ('data',)
+    ('data',)
   """
   num_devices = xr.global_runtime_device_count()
   mesh_shape = (num_devices,)
@@ -579,21 +579,18 @@ def mark_sharding(
   """
     Annotates the tensor provided with XLA partition spec. Internally, it annotates the
     corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
+
     Args:
         t (Union[torch.Tensor, XLAShardedTensor]): input tensor to be annotated with partition_spec.
         mesh (Mesh): describes the logical XLA device topology and the underlying device IDs.
         partition_spec (Tuple[Tuple, int, str, None]): A tuple of device_mesh dimension index or
-            `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str. 
-        This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
-        When a tuple is specified, the corresponding input tensor axis will be sharded along all
-        logical axes in the tuple. Note that the order the mesh axes are specified in the tuple
-        will impact the resulting sharding.
-        For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
-        >> input = torch.randn(8, 10)
-        >> mesh_shape = (4, 2)
-        >> partition_spec = (0, None)
+            `None`. Each index is an int, str if the mesh axis is named, or tuple of int or str.
+            This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
+            When a tuple is specified, the corresponding input tensor axis will be sharded along all
+            logical axes in the tuple. Note that the order the mesh axes are specified in the tuple
+            will impact the resulting sharding.
 
         dynamo_custom_op (bool): if set to True, it calls the dynamo custom op variant of mark_sharding
           to make itself recognizeable and traceable by dynamo.
@@ -606,12 +603,10 @@
     >>> num_devices = xr.global_runtime_device_count()
     >>> device_ids = np.array(range(num_devices))
     >>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-    >>> # 4-way data parallel
     >>> input = torch.randn(8, 32).to(xm.xla_device())
-    >>> xs.mark_sharding(input, mesh, (0, None))
-    >>> # 2-way model parallel
+    >>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel
     >>> linear = nn.Linear(32, 10).to(xm.xla_device())
-    >>> xs.mark_sharding(linear.weight, mesh, (None, 1))
+    >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
   num_devices = xr.global_runtime_device_count()
   assert num_devices > 0, "This requires XLA supported device(s)."
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 5c9417a2bc6..0b963e378ec 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -79,6 +79,11 @@ def _maybe_select_default_device():
 
 def device_type() -> Optional[str]:
   """Returns the current PjRt device type.
+
+  Selects a default device if none has been configured.
+
+  Returns:
+    A string representation of the device.
   """
   pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
   return pjrt_device.split('_')[0] if pjrt_device else pjrt_device
@@ -212,11 +217,18 @@ def addressable_runtime_device_count() -> int:
   return torch_xla._XLAC._xla_num_runtime_devices()
 
 
-# API to enable SPMD mode. This is a recommended way to enable SPMD.
-# This forces SPMD mode if some tensors are already initialized on non-SPMD
-# devices. This means that those tensors would be replicated across the devices.
 # TODO(yeounoh) introduce SPMD configuration.
 def use_spmd(auto: Optional[bool] = False):
+  """API to enable SPMD mode. This is the recommended way to enable SPMD.
+
+  This forces SPMD mode if some tensors are already initialized on non-SPMD
+  devices. This means that those tensors would be replicated across the devices.
+
+  Args:
+    auto (bool): Whether to enable auto-sharding. See
+      https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding
+      for more details.
+  """
  if os.environ.get("XLA_USE_SPMD") is not None:
    warnings.warn("XLA_USE_SPMD is being deprecated. "
                  "Use torch_xla.runtime.use_spmd() "
@@ -249,7 +261,8 @@ def get_master_ip() -> str:
   """Retrieve the master worker IP for the runtime. This calls into
   backend-specific discovery APIs.
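+
+  Example:
+
+    >>> import torch_xla.runtime as xr
+    >>> master_ip = xr.get_master_ip() # e.g. '10.0.0.1' (illustrative); TPU only
+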
-  Returns master worker's IP address as a string."""
+  Returns:
+    The master worker's IP address as a string."""
   if device_type() == 'TPU':
     return tpu.discover_master_worker_ip()
   raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
 
 
@@ -260,8 +273,8 @@ def initialize_cache(path: str, readonly: bool = False):
   before any computations have been performed.
 
   Args:
-    path: The path at which to store the persistent cache.
-    readonly: Whether or not this worker should have write access to the cache.
+    path (str): The path at which to store the persistent cache.
+    readonly (bool): If True, this worker will not write to the cache (read-only access).
   """
   assert not torch_xla._XLAC._xla_computation_cache_is_initialized(
   ), "Computation cache has already been initialized"