R2.4 #7316

Closed
wants to merge 9 commits

2 changes: 1 addition & 1 deletion WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = 'b604c8d87df842002a7a8de79a434026329fbcb2'
xla_hash = 'bf2dc9fe056bd7140e5f29a2ae6db15a26dd5443'

http_archive(
name = "xla",
116 changes: 37 additions & 79 deletions docs/fori_loop.md
@@ -1,114 +1,72 @@
# Fori_loop
`fori_loop` is a replacement for a pure Python `for` loop. PyTorch/XLA enables `torch_xla.experimental.fori_loop` to keep the loop computation graph rolled during compilation,
like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), instead of repeating the computation by enumerating every execution step
of each iteration. `fori_loop` can improve memory utilization and compilation speed.
# `While_loop` optimizes memory utilization and compilation

Users can call `fori_loop` like this:
```python
from torch_xla.experimental.fori_loop import fori_loop
res = fori_loop(upper, lower, body_fun, init)  # body_fun is user-defined
```
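
For intuition only, here is a rough pure-Python sketch of the computation `fori_loop` is meant to express. It is modeled on `jax.lax.fori_loop` semantics and is not the `torch_xla` implementation; the experimental signature above (and the TPU example later in this document) may order arguments and carry values differently.

```python
# Illustrative reference semantics only -- not the torch_xla implementation.
# Modeled on jax.lax.fori_loop: the body receives the loop index and the
# carried value, and the loop stays "rolled" instead of being unrolled.
def fori_loop_reference(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
```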

Currently `fori_loop` only supports simple test cases like [this one](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and users can also try the [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU.
<br>

For detailed implementation:
- When the loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop),
like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html). PyTorch/XLA supports `while_loop` with native PyTorch and the XLA backend's `XLA::While`. Because `while_loop` does not support autograd, it is used for inference only.
### `while_loop`
`while_loop` replaces a pure Python `while` loop. PyTorch supports `while_loop` through
[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66).
PyTorch/XLA provides experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`.

- When the loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan),
like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html). PyTorch/XLA enables `scan` using the `XLA::While` operator, so the implementation is very similar to `while_loop`. `scan` supports autograd, so it can be used in both training and inference.

# while_loop
`while_loop` is a replacement for a pure Python `while` loop. PyTorch supports `while_loop` in
[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69).
PyTorch/XLA aims to support `while_loop` with native PyTorch and the XLA backend's `XLA::While`.

Users can call `while_loop` like this:
#### Usage:
```python
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
res = while_loop(cond_fn, body_fn, init)  # cond_fn and body_fn are user-defined; init is a tuple or list
result = while_loop(cond_fn, body_fn, init)
```
Currently `while_loop` only supports simple test cases like [this one](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and users can also try the [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU.

- `cond_fn`: User-defined condition function.
- `body_fn`: User-defined loop body function.
- `init`: Initial values (tuple or list).
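
For intuition, here is a minimal pure-Python sketch of these semantics, assuming `init` is a tuple whose elements are passed positionally to `cond_fn` and `body_fn` (as in the TPU example below); it is not the XLA-backed implementation.

```python
# Reference semantics only -- not the XLA-backed implementation.
def while_loop_reference(cond_fn, body_fn, init):
    vals = tuple(init)
    while cond_fn(*vals):
        vals = tuple(body_fn(*vals))
    return vals
```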

# [WIP]scan
Like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA plans to enable `scan` for both training and inference, since it supports autograd.
`scan` is WIP.
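
For intuition, a rough pure-Python sketch of `scan` semantics, mirroring `jax.lax.scan`; the PyTorch/XLA API is still WIP, so treat this as illustrative only.

```python
# Reference semantics only, mirroring jax.lax.scan: fn consumes a carry and
# one slice of xs per step, and returns the new carry plus a per-step output.
def scan_reference(fn, init, xs):
    carry = init
    ys = []
    for x in xs:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys
```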


# Simple user guide
Users can try these three simple test cases to compare a pure Python `for` loop, `fori_loop`, and `while_loop`; all three share the same logic: cumulatively add 1 ten times:

### simple example with pure python for loop
```bash
# python
>>> import torch
>>> init = torch.tensor([0], dtype=torch.int32)
>>> one_value = torch.ones(1, dtype=torch.int32)
>>>
>>> for i in range(10):
... init = init + one_value
...
>>> init
tensor([10], dtype=torch.int32)
```

### simple example with `while_loop`:
#### simple example with `while_loop`:
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch_xla.experimental.fori_loop import fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(init, limit_value):
... return limit_value[0] >= init[0]
>>> def cond_fn(iteri, x):
... return iteri > 0
...
>>> def body_fn(init, limit_value):
... one_value = torch.ones(1, dtype=torch.int32, device=device)
... return (torch.add(init, one_value), limit_value.clone())
>>> def body_fn(iteri, x):
... return iteri - 1, torch.add(x, 1)
...
>>> init = torch.tensor([0], dtype=torch.int32, device=device)
>>> limit_value = torch.tensor([10], dtype=torch.int32, device=device)
>>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value))
>>> res_
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor([11], device='xla:0', dtype=torch.int32))
tensor(13, device='xla:0'))
```

### simple example with `fori_loop`:
<br>

## Control group test case
To better compare a pure Python `while` loop with `while_loop`, here is a control test case with similar logic: cumulatively adding 1 in a loop:

### Control group example with pure python `while` loop
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch_xla.experimental.fori_loop import fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
>>> lower = torch.tensor([2], dtype=torch.int32, device=device)
>>> upper = torch.tensor([52], dtype=torch.int32, device=device)
>>> plus_value = torch.tensor([1], dtype=torch.int32, device=device)
>>> init_val = torch.tensor([1], dtype=torch.int32, device=device)
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> def body_fun(*argus):
... plus_value, init_val = argus
... return plus_value, torch.add(plus_value, init_val)
>>> while iteri > 0:
... init_val = init_val + 1
... iteri -= 1
...
>>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val)
>>> res_
tensor([51], device='xla:0', dtype=torch.int32)
>>> init_val
tensor(51, device='xla:0')
```

For more examples and a detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA includes `while_loop` support in 2.3 for simple test cases; complex test cases and support for `fori_loop` and `scan` will be added after 2.3.


PyTorch/XLA includes `while_loop` support in 2.4 with a test case; support for `fori_loop` will be added after 2.4. For `while_loop`, `body_fn` must currently be defined so that its inputs and outputs (return args) have the same shapes.
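
As a hedged illustration of that shape constraint (hypothetical shapes, not taken from the test suite): each value returned by `body_fn` should match the shape of the corresponding input so the loop-carried state stays fixed across iterations.

```python
import torch

# Hypothetical body_fn illustrating the same-shape requirement.
def body_fn(iteri, x):
    # iteri is a scalar tensor and x has shape (4,); the outputs mirror both.
    return iteri - 1, x + 1

def bad_body_fn(iteri, x):
    # Changing the shape of a carried value (here, doubling x's length)
    # violates the requirement and is expected to fail under while_loop.
    return iteri - 1, torch.cat([x, x])
```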
58 changes: 58 additions & 0 deletions infra/tpu-pytorch-releases/artifacts.auto.tfvars
@@ -33,6 +33,64 @@ nightly_builds = [

# Built on push to specific tag.
versioned_builds = [
# Remove libtpu from PyPI builds
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.8"
bundle_libtpu = "0"
},
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.9"
bundle_libtpu = "0"
},
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.10"
bundle_libtpu = "0"
},
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.11"
bundle_libtpu = "0"
},
# Bundle libtpu for Kaggle
{
git_tag = "v2.4.0-rc1"
package_version = "2.4.0-rc1+libtpu"
pytorch_git_rev = "v2.4.0-rc1"
accelerator = "tpu"
python_version = "3.10"
bundle_libtpu = "1"
},
{
git_tag = "v2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
accelerator = "cuda"
cuda_version = "12.1"
python_version = "3.8"
},
{
git_tag = "v2.4.0-rc1"
pytorch_git_rev = "v2.4.0-rc1"
package_version = "2.4.0-rc1"
accelerator = "cuda"
cuda_version = "12.1"
python_version = "3.10"
},
# Remove libtpu from PyPI builds
{
git_tag = "v2.3.0"
4 changes: 2 additions & 2 deletions setup.py
@@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240605'
_date = '20240612'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
_jax_version = f'0.4.29.dev{_date}'
_jax_version = f'0.4.30.dev{_date}'


def _get_build_mode():
2 changes: 1 addition & 1 deletion test/run_tests.sh
@@ -203,7 +203,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

4 changes: 2 additions & 2 deletions test/spmd/test_sharding_strategies.py
@@ -67,9 +67,9 @@

num_devices = xr.global_runtime_device_count()

assert np.product(dcn_parallelism) * np.product(
assert np.prod(dcn_parallelism) * np.prod(
ici_parallelism) == num_devices, f"Number of devices {num_devices} \
does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}"
does not match the product of the parallelism {np.prod(dcn_parallelism) * np.prod(ici_parallelism)}"

# Use HybridMesh to optimize multislice topology
mesh = xs.HybridMesh(
28 changes: 28 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -172,6 +172,34 @@ def test_resharding_different_device_mesh(self):
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed to change mesh")
def test_resharding_transpose_device_mesh(self):
dim = self.n_devices // 2
model1 = self._get_sharded_model(mesh_shape=(dim, self.n_devices // dim))
model2 = self._get_sharded_model(mesh_shape=(self.n_devices // dim, dim))
self._save_and_restore(
model1,
model2,
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipIf(xr.global_runtime_device_count() == 1,
"Multiple devices needed to change mesh")
def test_padded_tensor(self):
# Use a linear layer with shape not divisible by the number of devices.
model1 = torch.nn.Linear(127, 63).to('xla')
model2 = torch.nn.Linear(127, 63).to('xla')
mesh = xs.Mesh(range(self.n_devices), (self.n_devices,))
# Transpose the sharding to induce resharding in the restore path
xs.mark_sharding(model1.weight, mesh, (0, None))
xs.mark_sharding(model2.weight, mesh, (None, 0))
self._save_and_restore(
model1,
model2,
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipUnless('CHKPT_PATH' in os.environ,
'CHKPT_PATH must be set for multihost checkpoint')
def test_multihost_checkpoint(self):