Commit cb482bc
[Doc] Update spmd.md for doc (#7019)
ManfeiBai authored May 30, 2024
1 parent 8fd051f commit cb482bc
Showing 1 changed file with 5 additions and 0 deletions: docs/spmd.md
@@ -489,6 +489,7 @@ Note that I used a batch size 4 times as large since I am running it on a TPU v4

We provide a shard placement visualization debug tool for PyTorch/XLA SPMD users on TPU/GPU/CPU, for both single-host and multi-host runs: use `visualize_tensor_sharding` to visualize a sharded tensor, or use `visualize_sharding` to visualize a sharding string. Here are two code examples on a single-host TPU (v4-8), one with `visualize_tensor_sharding` and one with `visualize_sharding` (a self-contained setup sketch follows the two snippets):
- Code snippet using `visualize_tensor_sharding` and its visualization result:

```python
import rich

# ... (setup lines collapsed in the diff view; see the sketch after the second snippet)
```

@@ -501,7 +502,9 @@

```python
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
```
![alt_text](assets/spmd_debug_1.png "visualize_tensor_sharding example on TPU v4-8 (single-host)")

- Code snippet using `visualize_sharding` and its visualization result:

```python
from torch_xla.distributed.spmd.debugging import visualize_sharding
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
```
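
The collapsed diff lines elide the mesh and tensor setup for the first snippet. For reference, here is a self-contained sketch of the full flow; it is not part of the diff, and it assumes a v4-8 host (4 visible chips), a hypothetical 2x2 mesh with axes `('x', 'y')`, and an 8x4 input tensor:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding, visualize_sharding

xr.use_spmd()  # enable SPMD execution mode

# Arrange the 4 chips of a v4-8 host into a 2x2 logical mesh.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(num_devices)), (2, 2), ('x', 'y'))

# Shard an 8x4 tensor: dim 0 across axis 'x', dim 1 across axis 'y'.
t = torch.randn(8, 4).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('x', 'y'))

generated_table = visualize_tensor_sharding(t, use_color=False)

# visualize_sharding renders a sharding string directly; no tensor is needed.
generated_table = visualize_sharding('{devices=[2,2]0,1,2,3}', use_color=False)
```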

@@ -517,11 +520,13 @@ We are introducing a new PyTorch/XLA SPMD feature, called ``auto-sharding``, [RF
PyTorch/XLA auto-sharding can be enabled by one of the following (a combined usage sketch follows this list):
- Setting the envvar `XLA_SPMD_AUTO=1`
- Calling the SPMD API at the beginning of your code:

```python
import torch_xla.runtime as xr
xr.use_spmd(auto=True)
```
- Calling `torch.distributed._tensor.distribute_module` with `auto_policy` and the `xla` device type:

```python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

# `MyModule` is a placeholder for your own nn.Module; the model is
# loaded onto the xla device via distribute_module.
model = MyModule()
sharded_model = distribute_module(model, device_mesh, auto_policy)
```
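
For context, here is a minimal end-to-end sketch combining the options above. It is not part of the diff; it assumes a hypothetical single-host TPU setup and uses only the `use_spmd(auto=True)` entry point shown earlier:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Enable auto-sharding; equivalent to launching with XLA_SPMD_AUTO=1.
xr.use_spmd(auto=True)

device = xm.xla_device()
model = nn.Linear(128, 64).to(device)
x = torch.randn(32, 128).to(device)

y = model(x)    # no manual mark_sharding calls: the XLA auto-sharding
xm.mark_step()  # pass decides the partitioning when the graph compiles
```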
