Skip to content

Commit

Permalink
[backport][Fori_loop|While_loop] Enable while_loop/fori_loop, add tes…
Browse files Browse the repository at this point in the history
…t case (#7157) (#7306)

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
  • Loading branch information
ManfeiBai and JackCaoG committed Jun 18, 2024
1 parent a901eb8 commit 2793462
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 247 deletions.
116 changes: 37 additions & 79 deletions docs/fori_loop.md
Original file line number Diff line number Diff line change
@@ -1,114 +1,72 @@
# Fori_loop
`fori_loop` is a replacement of pure python for loop, PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep loop computation graph as rolled during compilation
like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), not like currently repeat computations by enumerating all execution steps
of each iteration. `fori_loop` might help memory utilization and might help faster compilation.
# `While_loop` optimize memory utilization and compilation

User could use `fori_loop` like this:
```python
from torch_xla.experimental.fori_loop import fori_loop
res = fori_loop(upper, lower, /*user defined*/body_fun, init)
```

current fori_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU too.
<br>

For detailed implementation:
- for situation that loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop),
like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), PyTorch/XLA would support `while_loop` with the
native PyTorch and the XLA backend: XLA::While. Due to `while_loop` didn't support autograd, so it would be used for inference only.
### `while_loop`
`while_loop` replace pure python `while` loop, PyTorch supported `while_loop` by
[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66).
PyTorch/XLA provide experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`.

- for situation that loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan),
like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using XLA::While operator.
This implementation would be very similar like `while_loop`. `scan` support autograd, and it could be used in both training and inference.

# while_loop
`while_loop` is a replacement of pure python while loop, PyTorch has supported `while_loop` in
[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69).
PyTorch/XLA want to support `while_loop` with the native PyTorch and the XLA backend: XLA::While.

User could use `while_loop` like this:
#### Usage:
```python
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
res = while_loop(/*user-defined*/cond_fn, /*user-defined*/body_fn, /*tuple or list*/init)
result = while_loop(cond_fn, body_fn, init)
```
current while_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU too.

- `cond_fn`: User-defined condition function.
- `body_fn`: User-defined loop body function.
- `init`: Initial values (tuple or list).

# [WIP]scan
like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference since it support autograd.
`scan` is WIP.


# Simple user guide
User could try these three simple test case to better compare difference between `pure python for loop` and `fori_loop` and `while_loop`, these three test case have similar logic: cumulative plus 1 for ten times:

### simple example with pure python for loop
```bash
# python
>>> import torch
>>> init = torch.tensor([0], dtype=torch.int32)
>>> one_value = torch.ones(1, dtype=torch.int32)
>>>
>>> for i in range(10):
... init = init + one_value
...
>>> init
tensor([10], dtype=torch.int32)
```
### simple example with `while_loop`:
#### simple example with `while_loop`:
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch_xla.experimental.fori_loop import fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(init, limit_value):
... return limit_value[0] >= init[0]
>>> def cond_fn(iteri, x):
... return iteri > 0
...
>>> def body_fn(init, limit_value):
... one_value = torch.ones(1, dtype=torch.int32, device=device)
... return (torch.add(init, one_value), limit_value.clone())
>>> def body_fn(iteri, x):
... return iteri - 1, torch.add(x, 1)
...
>>> init = torch.tensor([0], dtype=torch.int32, device=device)
>>> limit_value = torch.tensor([10], dtype=torch.int32, device=device)
>>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value))
>>> res_
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor([11], device='xla:0', dtype=torch.int32))
tensor(13, device='xla:0'))
```

### simple example with `fori_loop`:
<br>

## Control group test case
For better compare difference between `pure python while loop` and `while_loop`, there is one test case called pure python `while` loop with similar logic: cumulative plus 1 for ten times:

### Control group example with pure python `while` loop
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch_xla.experimental.fori_loop import fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
>>> lower = torch.tensor([2], dtype=torch.int32, device=device)
>>> upper = torch.tensor([52], dtype=torch.int32, device=device)
>>> plus_value = torch.tensor([1], dtype=torch.int32, device=device)
>>> init_val = torch.tensor([1], dtype=torch.int32, device=device)
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> def body_fun(*argus):
... plus_value, init_val = argus
... return plus_value, torch.add(plus_value, init_val)
>>> while iteri > 0:
... init_val = init_val + 1
... iteri -= 1
...
>>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val)
>>> res_
tensor([51], device='xla:0', dtype=torch.int32)
>>> init_val
tensor(51, device='xla:0')
```
For more example and detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test case, complex test case and support for `fori_loop` and `scan` would be added after 2.3
PyTorch/XLA would include `while_loop` support in 2.4 with test case, support for `fori_loop` would be added after 2.4. For `while_loop`, currently we only should force define `body_fn` with same `input` and `output(return args)` shape
2 changes: 1 addition & 1 deletion test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

Expand Down
106 changes: 0 additions & 106 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py

This file was deleted.

116 changes: 116 additions & 0 deletions test/test_while_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb
import torch_xla.utils.utils as xu
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def _fake_while_loop(cond_fn, body_fn, operands):
# operands need to be more than one here
while cond_fn(*operands):
operands = body_fn(*operands)
return operands


class WhileLoopTest(unittest.TestCase):

def test_while_loop_addition(self):
device = xm.xla_device()

def cond_fn(iteri, x):
return iteri > 0

def body_fn(iteri, x):
return iteri - 1, torch.add(x, 1)

init_val = torch.tensor(3, dtype=torch.int32, device=device)
iteri = torch.tensor(10, device=device)
_, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
_, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

def test_while_loop_addition_nested(self):
device = xm.xla_device()

def cond_fn(iteri, x):
return iteri > 0

def body_fn(iteri, x):
return iteri - 1, torch.add(torch.add(x, 1), 1)

init_val = torch.tensor(2, dtype=torch.int32, device=device)
iteri = torch.tensor(10, device=device)
_, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
_, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

def test_while_loop_simple_linear_inside_loop(self):
device = xm.xla_device()
torch.set_grad_enabled(False)

class SimpleLinear(torch.nn.Module):

def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, iteri, x):

def cond_fn(iteri, x):
return iteri > 0

def body_fn(iteri, x):
return iteri - 1, self.linear(x)

return while_loop(cond_fn, body_fn, (iteri, x))

def forward_without_while_loop_op(self, iteri, x):
while (iteri > 0):
x = self.linear(x)
iteri -= 1
return iteri, x

linear_model = SimpleLinear()
linear_model.to(device)
l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
iteri = torch.tensor(10, dtype=torch.int32, device=device)
_, res_with_loop = linear_model(iteri, l_in_0)
_, res_without_loop = linear_model.forward_without_while_loop_op(
iteri, l_in_0)

self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

# ====== fori_loop ======
@unittest.skip("Fori_loop is not supported now due to unstable result.")
def test_fori_loop_addition(self):
device = xm.xla_device()

lower = torch.tensor(0, device=device)
upper = torch.tensor(50, device=device)
init_val = torch.tensor(1, dtype=torch.int32, device=device)

def body_fun(x):
return torch.add(x, 1)

_, res_with_loop = fori_loop(lower, upper, body_fun, (init_val))

# === expected ===
for i in range(upper - lower):
init_val = torch.add(init_val, 1)
res_without_loop = init_val


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ python3 test/dynamo/test_dynamo.py
python3 test/spmd/test_spmd_debugging.py
python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
python3 test/test_while_loop.py
python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py
Expand Down
Loading

0 comments on commit 2793462

Please sign in to comment.