diff --git a/docs/fori_loop.md b/docs/fori_loop.md
index 0c9f85af399..c29e32e28b3 100644
--- a/docs/fori_loop.md
+++ b/docs/fori_loop.md
@@ -1,114 +1,72 @@
-# Fori_loop
-`fori_loop` is a replacement of pure python for loop, PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep loop computation graph as rolled during compilation
-like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), not like currently repeat computations by enumerating all execution steps
-of each iteration. `fori_loop` might help memory utilization and might help faster compilation.
+# `while_loop` optimizes memory utilization and compilation
-User could use `fori_loop` like this:
-```python
-from torch_xla.experimental.fori_loop import fori_loop
-res = fori_loop(upper, lower, /*user defined*/body_fun, init)
-```
-
-current fori_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU too.
+
-For detailed implementation:
-- for situation that loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop),
-like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), PyTorch/XLA would support `while_loop` with the
-native PyTorch and the XLA backend: XLA::While. Due to `while_loop` didn't support autograd, so it would be used for inference only.
+### `while_loop`
+`while_loop` replaces a pure Python `while` loop; PyTorch supports `while_loop` via
+[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66).
+PyTorch/XLA provides experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`.
-- for situation that loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan),
-like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using XLA::While operator.
-This implementation would be very similar like `while_loop`. `scan` support autograd, and it could be used in both training and inference.
-
-# while_loop
-`while_loop` is a replacement of pure python while loop, PyTorch has supported `while_loop` in
-[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69).
-PyTorch/XLA want to support `while_loop` with the native PyTorch and the XLA backend: XLA::While.
-
-User could use `while_loop` like this:
+#### Usage:
```python
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
-res = while_loop(/*user-defined*/cond_fn, /*user-defined*/body_fn, /*tuple or list*/init)
+result = while_loop(cond_fn, body_fn, init)
```
-current while_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU too.
-
+- `cond_fn`: User-defined condition function.
+- `body_fn`: User-defined loop body function.
+- `init`: Initial values (tuple or list).
-# [WIP]scan
-like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference since it support autograd.
-`scan` is WIP.
-
-
-# Simple user guide
-User could try these three simple test case to better compare difference between `pure python for loop` and `fori_loop` and `while_loop`, these three test case have similar logic: cumulative plus 1 for ten times:
-
-### simple example with pure python for loop
-```bash
-# python
->>> import torch
->>> init = torch.tensor([0], dtype=torch.int32)
->>> one_value = torch.ones(1, dtype=torch.int32)
->>>
->>> for i in range(10):
-... init = init + one_value
-...
->>> init
-tensor([10], dtype=torch.int32)
-```
-
-### simple example with `while_loop`:
+#### simple example with `while_loop`:
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
->>> from torch_xla.experimental.fori_loop import fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
->>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
->>> def cond_fn(init, limit_value):
-... return limit_value[0] >= init[0]
+>>> def cond_fn(iteri, x):
+... return iteri > 0
...
->>> def body_fn(init, limit_value):
-... one_value = torch.ones(1, dtype=torch.int32, device=device)
-... return (torch.add(init, one_value), limit_value.clone())
+>>> def body_fn(iteri, x):
+... return iteri - 1, torch.add(x, 1)
...
->>> init = torch.tensor([0], dtype=torch.int32, device=device)
->>> limit_value = torch.tensor([10], dtype=torch.int32, device=device)
->>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value))
->>> res_
+>>> init_val = torch.tensor(3, device=device)
+>>> iteri = torch.tensor(10, device=device)
+>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
+>>> res
FunctionalTensor(lvl=0, value=\
-tensor([11], device='xla:0', dtype=torch.int32))
+tensor(13, device='xla:0'))
```
-### simple example with `fori_loop`:
+
+
+## Control group test case
+To better compare a pure python `while` loop with `while_loop`, here is a control test case with similar logic: cumulatively adding 1 on each iteration:
+
+### Control group example with pure python `while` loop
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
->>> import torch_xla.experimental.fori_loop
->>> from torch_xla.experimental.fori_loop import fori_loop
->>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
->>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
->>> lower = torch.tensor([2], dtype=torch.int32, device=device)
->>> upper = torch.tensor([52], dtype=torch.int32, device=device)
->>> plus_value = torch.tensor([1], dtype=torch.int32, device=device)
->>> init_val = torch.tensor([1], dtype=torch.int32, device=device)
+>>> init_val = torch.tensor(1, device=device)
+>>> iteri = torch.tensor(50, device=device)
>>>
->>> def body_fun(*argus):
-... plus_value, init_val = argus
-... return plus_value, torch.add(plus_value, init_val)
+>>> while iteri > 0:
+... init_val = init_val + 1
+... iteri -= 1
...
->>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val)
->>> res_
-tensor([51], device='xla:0', dtype=torch.int32)
+>>> init_val
+tensor(51, device='xla:0')
```
-For more example and detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test case, complex test case and support for `fori_loop` and `scan` would be added after 2.3
+
+
+PyTorch/XLA plans to include `while_loop` support, with test cases, in release 2.4; support for `fori_loop` would be added after 2.4. For `while_loop`, `body_fn` currently must be defined so that its `input` and `output (return args)` have the same shapes.
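+
+A minimal sketch of this constraint (the ops and tensor shapes here are illustrative, not part of the API): every tensor returned by `body_fn` keeps the shape and dtype of the corresponding input:
+
+```python
+import torch
+import torch_xla.experimental.fori_loop  # registers the XLA dispatch for while_loop
+from torch._higher_order_ops.while_loop import while_loop
+import torch_xla.core.xla_model as xm
+
+device = xm.xla_device()
+
+def cond_fn(iteri, x):
+  return iteri > 0
+
+def body_fn(iteri, x):
+  # (scalar, 2x2) in -> (scalar, 2x2) out: shapes and dtypes are preserved
+  return iteri - 1, torch.matmul(x, x)
+
+iteri = torch.tensor(3, device=device)
+x = torch.eye(2, device=device)
+_, res = while_loop(cond_fn, body_fn, (iteri, x))
+```
+
+A `body_fn` that changed the shape of `x` (for example via `torch.cat`) would not meet this requirement.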
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 4a298f01ee5..26d3c82303e 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -203,7 +203,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
- run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
+ run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}
diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
deleted file mode 100644
index a76197cc736..00000000000
--- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import os
-import unittest
-from typing import Callable, Dict, List
-
-import torch
-import torch_xla
-# We need to import the underlying implementation function to register with the dispatcher
-import torch_xla.experimental.fori_loop
-from torch_xla.experimental.fori_loop import fori_loop
-from torch._higher_order_ops.while_loop import while_loop
-import torch_xla.core.xla_model as xm
-import torch_xla.core.xla_builder as xb
-
-
-def _fake_while_loop(cond_fn, body_fn, operands):
- # operands need to be more than one here
- while cond_fn(*operands):
- operands = body_fn(*operands)
- return operands
-
-
-def _fake_fori_loop(lower, upper, body_fun, *init_val):
- (plus_value, init_val) = init_val
- for i in range((upper - lower)[0]):
- plus_value, init_val = body_fun(plus_value, init_val)
- return init_val
-
-
-class WhileLoopTest(unittest.TestCase):
-
- def test_while_loop_tpu_subtraction(self):
-
- device = xm.xla_device()
-
- def cond_fn(init, limit_value):
- return limit_value[0] <= init[0]
-
- def body_fn(init, limit_value):
- one_value = torch.ones(1, dtype=torch.int32, device=device)
- two_value = limit_value.clone()
- return (torch.sub(init, one_value), two_value)
-
- init = torch.tensor([10], dtype=torch.int32, device=device)
- limit_value = torch.tensor([0], dtype=torch.int32, device=device)
- res = while_loop(cond_fn, body_fn, (init, limit_value))
- expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
- self.assertEqual(expected, res)
-
- def test_while_loop_tpu_addition(self):
-
- device = xm.xla_device()
-
- def cond_fn(init, limit_value):
- return limit_value[0] >= init[0]
-
- def body_fn(init, limit_value):
- one_value = torch.ones(1, dtype=torch.int32, device=device)
- return (torch.add(init, one_value), limit_value.clone())
-
- # TODO(@manfei): init and limit_value has to be torch.tensor.
- init = torch.tensor([0], dtype=torch.int32, device=device)
- limit_value = torch.tensor([10], dtype=torch.int32, device=device)
- res = while_loop(cond_fn, body_fn, (init, limit_value))
- expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
- self.assertEqual(expected, res)
-
- def test_while_loop_tpu_subtraction_nested(self):
-
- device = xm.xla_device()
-
- def cond_fn(init, limit_value):
- return limit_value[0] <= init[0]
-
- def body_fn(init, limit_value):
- one_value = torch.ones(1, dtype=torch.int32, device=device)
- two_value = limit_value.clone()
- return (torch.sub(torch.sub(init, one_value), one_value), two_value)
-
- init = torch.tensor([10], dtype=torch.int32, device=device)
- limit_value = torch.tensor([0], dtype=torch.int32, device=device)
- res = while_loop(cond_fn, body_fn, (init, limit_value))
- expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value))
- self.assertEqual(expected, res)
-
- def test_fori_loop_tpu_addition(self):
-
- xm.mark_step()
- device = xm.xla_device()
-
- lower = torch.tensor([2], dtype=torch.int32, device=device)
- upper = torch.tensor([52], dtype=torch.int32, device=device)
- plus_value = torch.tensor([1], dtype=torch.int32, device=device)
- init_val = torch.tensor([1], dtype=torch.int32, device=device)
-
- def body_fun(*argus):
- plus_value, init_val = argus
- return plus_value, torch.add(plus_value, init_val)
-
- _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val)
- expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val)
- self.assertEqual(expected, actual)
-
-
-if __name__ == '__main__':
- test = unittest.main()
- sys.exit(0 if test.result.wasSuccessful() else 1)
\ No newline at end of file
diff --git a/test/test_while_loop.py b/test/test_while_loop.py
new file mode 100644
index 00000000000..e8ea617b0f9
--- /dev/null
+++ b/test/test_while_loop.py
@@ -0,0 +1,116 @@
+import os
+import sys
+import unittest
+from typing import Callable, Dict, List
+
+import torch
+import torch_xla
+# We need to import the underlying implementation function to register with the dispatcher
+import torch_xla.experimental.fori_loop
+from torch_xla.experimental.fori_loop import fori_loop
+from torch._higher_order_ops.while_loop import while_loop
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_builder as xb
+import torch_xla.utils.utils as xu
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+
+def _fake_while_loop(cond_fn, body_fn, operands):
+ # operands need to be more than one here
+ while cond_fn(*operands):
+ operands = body_fn(*operands)
+ return operands
+
+
+class WhileLoopTest(unittest.TestCase):
+
+ def test_while_loop_addition(self):
+ device = xm.xla_device()
+
+ def cond_fn(iteri, x):
+ return iteri > 0
+
+ def body_fn(iteri, x):
+ return iteri - 1, torch.add(x, 1)
+
+ init_val = torch.tensor(3, dtype=torch.int32, device=device)
+ iteri = torch.tensor(10, device=device)
+ _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
+ _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
+ self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+ def test_while_loop_addition_nested(self):
+ device = xm.xla_device()
+
+ def cond_fn(iteri, x):
+ return iteri > 0
+
+ def body_fn(iteri, x):
+ return iteri - 1, torch.add(torch.add(x, 1), 1)
+
+ init_val = torch.tensor(2, dtype=torch.int32, device=device)
+ iteri = torch.tensor(10, device=device)
+ _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
+ _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
+ self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+ def test_while_loop_simple_linear_inside_loop(self):
+ device = xm.xla_device()
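+    # while_loop does not yet support autograd, so run this test with grad disabled (inference only)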
+ torch.set_grad_enabled(False)
+
+ class SimpleLinear(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(2, 2)
+
+ def forward(self, iteri, x):
+
+ def cond_fn(iteri, x):
+ return iteri > 0
+
+ def body_fn(iteri, x):
+ return iteri - 1, self.linear(x)
+
+ return while_loop(cond_fn, body_fn, (iteri, x))
+
+ def forward_without_while_loop_op(self, iteri, x):
+ while (iteri > 0):
+ x = self.linear(x)
+ iteri -= 1
+ return iteri, x
+
+ linear_model = SimpleLinear()
+ linear_model.to(device)
+ l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
+ iteri = torch.tensor(10, dtype=torch.int32, device=device)
+ _, res_with_loop = linear_model(iteri, l_in_0)
+ _, res_without_loop = linear_model.forward_without_while_loop_op(
+ iteri, l_in_0)
+
+ self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+ # ====== fori_loop ======
+ @unittest.skip("Fori_loop is not supported now due to unstable result.")
+ def test_fori_loop_addition(self):
+ device = xm.xla_device()
+
+ lower = torch.tensor(0, device=device)
+ upper = torch.tensor(50, device=device)
+ init_val = torch.tensor(1, dtype=torch.int32, device=device)
+
+ def body_fun(x):
+ return torch.add(x, 1)
+
+ _, res_with_loop = fori_loop(lower, upper, body_fun, (init_val))
+
+ # === expected ===
+ for i in range(upper - lower):
+ init_val = torch.add(init_val, 1)
+    res_without_loop = init_val
+    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+ sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 6d74c40d4e3..ddd439d1c60 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -20,7 +20,7 @@ python3 test/dynamo/test_dynamo.py
python3 test/spmd/test_spmd_debugging.py
python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
-python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+python3 test/test_while_loop.py
python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 3fba13773b3..db3e54d8163 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -933,22 +933,9 @@ class PyLoweringContext {
}
// Builds a HLO graph given a set of output tensors, and add unused parameters
- // needed in xlacomputation.
+ // needed in xlacomputation for fori_loop/while_loop.
  void BuildForiLoop(std::vector<at::Tensor> tensors,
-                     std::vector<at::Tensor> input_arguments = {}) {
- if (GetNameString() == "condctx") {
- xla::XlaBuilder* local_builder = lowering_ctx.builder();
- // hard-code parameter_idx to 2 to skip existing upper/lower arguments
- int64_t parameter_idx = 2;
- for (at::Tensor input_argument : input_arguments) {
- xla::Shape shape =
- xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1});
- xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
- "UnusedArgumentsPlaceholder");
- parameter_idx += 1;
- }
- }
-
+                     std::vector<at::Tensor> additional_inputs_list = {}) {
// Get the backing XLA tensors from the output torch tensor handles
    std::vector<XLATensorPtr> xtensors =
GetXlaTensors(tensors, /*want_all=*/true);
@@ -966,6 +953,24 @@ class PyLoweringContext {
torch::lazy::Output(ir_value.node.get(), ir_value.index));
lowering_ctx.AddResult(root);
}
+
+    // add dummy parameters to the cond/body xlacomputation's inputs to satisfy
+    // the xla::While requirement
+ if ((GetNameString() == "condctx") or
+ (GetNameString() == "bodyctx" && additional_inputs_list.size() != 0)) {
+ xla::XlaBuilder* local_builder = lowering_ctx.builder();
+ int64_t parameter_idx =
+ local_builder->GetProgramShape()->parameters_size();
+ int64_t additional_inputs_list_size = additional_inputs_list.size();
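+      // additional_inputs_list carries the full expected input list for this
+      // computation; append placeholder parameters for the entries that were
+      // not already lowered as parameters.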
+ for (int64_t i = parameter_idx; i < additional_inputs_list_size; i++) {
+ XLATensorPtr xtensor = bridge::GetXlaTensor(additional_inputs_list[i]);
+ xla::Shape shape = xtensor->shape().get();
+ xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape,
+ "UnusedArgumentsPlaceholder");
+ parameter_idx += 1;
+ }
+ }
+
computation = ConsumeValue(lowering_ctx.BuildXla());
// wrap inputs of cond/body_computation
diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py
index bf32a712f3e..e41709084e2 100644
--- a/torch_xla/experimental/fori_loop.py
+++ b/torch_xla/experimental/fori_loop.py
@@ -10,80 +10,141 @@
from torch._ops import HigherOrderOperator
import torch._higher_order_ops.while_loop
from torch._higher_order_ops.while_loop import while_loop_op
+from torch._higher_order_ops.while_loop import while_loop as torch_while_loop
+from torch._higher_order_ops.utils import _has_potential_branch_input_mutation
-def fori_loop(lower, upper, user_body_func, *init_val):
+def fori_loop(lower, upper, body_fun, *input_value):
device = xm.xla_device()
+ if (upper < lower):
+    print("ERROR: upper should be greater than lower")
+ iteri = upper - lower
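+  # fori_loop is lowered to while_loop: run (upper - lower) iterations by counting iteri down to 0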
- def cond_fn(upper, lower, *init_val):
- return lower[0] < upper[0]
+ def cond_fn(iteri, *input_value):
+ return iteri > 0
- def body_fn(upper, lower, *init_val):
- one_value_i = torch.ones(1, dtype=torch.int32, device=device)
- res_list = list(user_body_func(*init_val))
- res_list.insert(0, lower)
- res_list.insert(0, torch.sub(upper, one_value_i))
- return res_list
+ def new_body_fn(iteri, *input_value):
+ return iteri - 1, body_fun(*input_value)
+
+ inputs = (iteri,) + input_value
+ res = _xla_while_loop_wrapper(
+ cond_fn, new_body_fn, inputs, (), fake_tensor=True)
- res = while_loop(cond_fn, body_fn, (lower, upper, *init_val))
return res
@while_loop_op.py_impl(DispatchKey.XLA)
-def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None):
- # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '')
- # cond_fn&body_fn: callable
- # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors)
+def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs=None):
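+  # carried_inputs: tuple or list of tensors; additional_inputs: tensors captured from the closure (e.g. weights/bias)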
if additional_inputs is None:
additional_inputs = tuple()
- return _xla_while_loop(
- cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs)
+ return _xla_while_loop_wrapper(cond_fn, body_fn, carried_inputs,
+ additional_inputs)
+
+
+def _xla_while_loop_wrapper(cond_fn,
+ body_fn,
+ carried_inputs,
+ additional_inputs=None,
+ fake_tensor=False):
+
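+  # wrap body_fn so its outputs follow the (iter, additional_inputs, values)
+  # order that the generated body xlacomputation is expected to produce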
+ def new_body_fn(*carried_inputs):
+ res = list(body_fn(*carried_inputs))
+ if additional_inputs:
+ res = [
+ res[0],
+ ] + list(additional_inputs) + res[1:]
+ return res
+ return _xla_while_loop(cond_fn, new_body_fn, carried_inputs,
+ additional_inputs, fake_tensor)
-def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
- # untuple carried_inputs from while_loop
- carried_inputs = carried_inputs[0]
- # fake carried_inputs to split formal code
+
+def _xla_while_loop(cond_fn,
+ body_fn,
+ carried_inputs,
+ additional_inputs=None,
+ fake_tensor=False):
+
+ # ====== fake_carried_inputs ======
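+  # placeholder tensors matching carried_inputs' shapes/dtypes; used only to trace and build the cond/body xlacomputations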
fake_carried_inputs = []
for carried_input in carried_inputs:
device = carried_input.device
fake_carried_inputs.append(
torch.randint(10, carried_input.size(),
dtype=carried_input.dtype).to(device))
- fake_carried_inputs = tuple(fake_carried_inputs)
- # trans fake_carried_inputs from list(tensor) to list(xla::op)
- kwargs = {}
- if type(fake_carried_inputs) is tuple:
- shapes = xb.tensor_shape(fake_carried_inputs)
+  # ====== fake_additional_inputs ======
+  fake_additional_args = []
+  for additional_input in additional_inputs:
+    device = additional_input.device
+    fake_additional_args.append(
+ torch.randint(
+ 10, additional_input.size(),
+ dtype=additional_input.dtype).to(device))
+
+ # ====== inputs_list ======
+  # build body_fn_inputs/cond_fn_inputs, appending captured additional_inputs to the fn inputs
+ if additional_inputs or fake_tensor:
+    # replace the carried values (carried_inputs[1:]) with fake tensors to avoid the missing-arguments problem
+ body_fn_inputs = [
+ carried_inputs[0],
+ ] + fake_carried_inputs[1:] + list(additional_inputs)
+ cond_fn_inputs = carried_inputs + additional_inputs
else:
- shapes = xb.tensor_shape((fake_carried_inputs))
- builder = xb.create_builder('test_while')
- params = []
- for shape in shapes:
- p = xb.mkparam(builder, len(params), shape)
- params.append(p)
+ body_fn_inputs = carried_inputs
+ cond_fn_inputs = carried_inputs
+
+  # `xla::While` requires the body xlacomputation's inputs/outputs, the cond xlacomputation's inputs and the init to have the same shapes and types;
+  # carried_inputs contains (iter, values) and additional_inputs contains (weights/bias).
+  # The generated body xlacomputation produces outputs in the order (iter, weights/bias, values),
+  # so we build the cond/body xlacomputations against that same expected order: (iter, weights/bias, values)
+ dummy_inputs_list = [
+ fake_carried_inputs[0],
+  ] + fake_additional_args + fake_carried_inputs[1:]
+
+ # ====== body_fn ======
+ body_result = body_fn(*body_fn_inputs)
+ body_ctx = torch_xla._XLAC.lowering.LoweringContext()
+ body_ctx.set_name_string("bodyctx")
- # generate cond_fn xlacomputation
- cond_result = cond_fn(*fake_carried_inputs)
+ # ====== body xlacomputation ======
+ body_ctx.buildforiloop(list(body_result), dummy_inputs_list)
+ body_hlo = body_ctx.hlo()
+ body_computation = xb.computation_from_module_proto("bodycomputation",
+ body_hlo)
+
+ # ====== cond_fn ======
+ cond_result = cond_fn(*cond_fn_inputs)
cond_ctx = torch_xla._XLAC.lowering.LoweringContext()
cond_ctx.set_name_string("condctx")
- cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:]))
+
+ # ====== cond xlacomputation ======
+ cond_ctx.buildforiloop([cond_result], dummy_inputs_list)
cond_hlo = cond_ctx.hlo()
cond_computation = xb.computation_from_module_proto("condcomputation",
cond_hlo)
- # generate body_fn xlacomputation
- body_result = body_fn(*fake_carried_inputs)
- body_ctx = torch_xla._XLAC.lowering.LoweringContext()
- body_ctx.set_name_string("bodyctx")
- body_ctx.buildforiloop(list(body_result), [])
- body_hlo = body_ctx.hlo()
- body_computation = xb.computation_from_module_proto("bodycomputation",
- body_hlo)
+ # ====== xla::while ======
+ iter_value = carried_inputs[0]
+ input_and_outputs_value = carried_inputs[1:]
+ total_inputs = tuple([
+ iter_value,
+ ]) + tuple(additional_inputs) + tuple(input_and_outputs_value)
+
+ kwargs = {}
+ if type(total_inputs) is tuple:
+ shapes = xb.tensor_shape(total_inputs)
+ else:
+ shapes = xb.tensor_shape((total_inputs))
+ builder = xb.create_builder('while_loop')
+ params = []
+ for shape in shapes:
+ p = xb.mkparam(builder, len(params), shape)
+ params.append(p)
- # generate while xlacomputation
input_tuple = xb.Op.tuple(tuple(params))
w = xb.mkop(
'While', (input_tuple.op,),
@@ -94,6 +155,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs):
# gain final result with generated while xlacomputation
result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while',
- (carried_inputs), computation)
+ (total_inputs), computation)
+
+  # unwrap the result: skip the additional_inputs entries so the output keeps the original (iter, values) order
+ additional_inputs_len = len(additional_inputs) + 1
+ final_res = [
+ result[0],
+ ] + result[additional_inputs_len:]
- return result
\ No newline at end of file
+ return final_res