diff --git a/docs/fori_loop.md b/docs/fori_loop.md index 0c9f85af399..c29e32e28b3 100644 --- a/docs/fori_loop.md +++ b/docs/fori_loop.md @@ -1,114 +1,72 @@ -# Fori_loop -`fori_loop` is a replacement of pure python for loop, PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep loop computation graph as rolled during compilation -like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), not like currently repeat computations by enumerating all execution steps -of each iteration. `fori_loop` might help memory utilization and might help faster compilation. +# `While_loop` optimizes memory utilization and compilation -User could use `fori_loop` like this: -```python -from torch_xla.experimental.fori_loop import fori_loop -res = fori_loop(upper, lower, /*user defined*/body_fun, init) -``` - -current fori_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU too. +
-For detailed implementation: -- for situation that loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop), -like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), PyTorch/XLA would support `while_loop` with the -native PyTorch and the XLA backend: XLA::While. Due to `while_loop` didn't support autograd, so it would be used for inference only. +### `while_loop` +`while_loop` replaces a pure Python `while` loop; PyTorch supports `while_loop` via +[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66). +PyTorch/XLA provides experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`. -- for situation that loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan), -like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using XLA::While operator. -This implementation would be very similar like `while_loop`. `scan` support autograd, and it could be used in both training and inference. - -# while_loop -`while_loop` is a replacement of pure python while loop, PyTorch has supported `while_loop` in -[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69). -PyTorch/XLA want to support `while_loop` with the native PyTorch and the XLA backend: XLA::While. - -User could use `while_loop` like this: +#### Usage: ```python import torch_xla.experimental.fori_loop from torch._higher_order_ops.while_loop import while_loop -res = while_loop(/*user-defined*/cond_fn, /*user-defined*/body_fn, /*tuple or list*/init) +result = while_loop(cond_fn, body_fn, init) ``` -current while_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU too. - +- `cond_fn`: User-defined condition function. +- `body_fn`: User-defined loop body function. +- `init`: Initial values (tuple or list). -# [WIP]scan -like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference since it support autograd. -`scan` is WIP. - - -# Simple user guide -User could try these three simple test case to better compare difference between `pure python for loop` and `fori_loop` and `while_loop`, these three test case have similar logic: cumulative plus 1 for ten times: - -### simple example with pure python for loop -```bash -# python ->>> import torch ->>> init = torch.tensor([0], dtype=torch.int32) ->>> one_value = torch.ones(1, dtype=torch.int32) ->>> ->>> for i in range(10): -... init = init + one_value -... 
->>> init -tensor([10], dtype=torch.int32) -``` - -### simple example with `while_loop`: +#### simple example with `while_loop`: ```bash # PJRT_DEVICE=TPU python >>> import torch >>> import torch_xla >>> import torch_xla.experimental.fori_loop ->>> from torch_xla.experimental.fori_loop import fori_loop >>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm ->>> import torch_xla.core.xla_builder as xb >>> >>> device = xm.xla_device() >>> ->>> def cond_fn(init, limit_value): -... return limit_value[0] >= init[0] +>>> def cond_fn(iteri, x): +... return iteri > 0 ... ->>> def body_fn(init, limit_value): -... one_value = torch.ones(1, dtype=torch.int32, device=device) -... return (torch.add(init, one_value), limit_value.clone()) +>>> def body_fn(iteri, x): +... return iteri - 1, torch.add(x, 1) ... ->>> init = torch.tensor([0], dtype=torch.int32, device=device) ->>> limit_value = torch.tensor([10], dtype=torch.int32, device=device) ->>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value)) ->>> res_ +>>> init_val = torch.tensor(3, device=device) +>>> iteri = torch.tensor(10, device=device) +>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val)) +>>> res FunctionalTensor(lvl=0, value=\ -tensor([11], device='xla:0', dtype=torch.int32)) +tensor(13, device='xla:0')) ``` -### simple example with `fori_loop`: +
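+
+A slightly larger sketch, adapted from `test_while_loop_simple_linear_inside_loop` in `test/test_while_loop.py`, drives a small `torch.nn.Linear` module with `while_loop` (gradients are disabled, matching that test; treat it as an illustrative sketch rather than a supported recipe):
+
+```python
+# Illustrative sketch based on the new test; run with PJRT_DEVICE=TPU.
+import torch
+import torch_xla
+import torch_xla.experimental.fori_loop  # registers the XLA dispatch for while_loop
+from torch._higher_order_ops.while_loop import while_loop
+import torch_xla.core.xla_model as xm
+
+device = xm.xla_device()
+torch.set_grad_enabled(False)
+
+class SimpleLinear(torch.nn.Module):
+
+  def __init__(self):
+    super().__init__()
+    self.linear = torch.nn.Linear(2, 2)
+
+  def forward(self, iteri, x):
+
+    def cond_fn(iteri, x):
+      return iteri > 0
+
+    def body_fn(iteri, x):
+      return iteri - 1, self.linear(x)
+
+    return while_loop(cond_fn, body_fn, (iteri, x))
+
+model = SimpleLinear().to(device)
+iteri = torch.tensor(10, dtype=torch.int32, device=device)
+x = torch.randn(2, 2, dtype=torch.float32, device=device)
+_, out = model(iteri, x)  # applies the same linear layer ten times
+```
+
+The layer's weight and bias are captured by `body_fn` and handled as `additional_inputs` by the XLA lowering, while the caller still receives only `(iteri, output)`.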
+ +## Control group test case +To better compare `while_loop` with a pure Python `while` loop, here is a control test case with the same logic: cumulatively add 1 for fifty iterations: + +### Control group example with a pure Python `while` loop ```bash # PJRT_DEVICE=TPU python >>> import torch >>> import torch_xla ->>> import torch_xla.experimental.fori_loop ->>> from torch_xla.experimental.fori_loop import fori_loop ->>> from torch._higher_order_ops.while_loop import while_loop >>> import torch_xla.core.xla_model as xm ->>> import torch_xla.core.xla_builder as xb >>> >>> device = xm.xla_device() >>> ->>> lower = torch.tensor([2], dtype=torch.int32, device=device) ->>> upper = torch.tensor([52], dtype=torch.int32, device=device) ->>> plus_value = torch.tensor([1], dtype=torch.int32, device=device) ->>> init_val = torch.tensor([1], dtype=torch.int32, device=device) +>>> init_val = torch.tensor(1, device=device) +>>> iteri = torch.tensor(50, device=device) >>> ->>> def body_fun(*argus): -... plus_value, init_val = argus -... return plus_value, torch.add(plus_value, init_val) +>>> while iteri > 0: +... init_val = init_val + 1 +... iteri -= 1 ... ->>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val) ->>> res_ -tensor([51], device='xla:0', dtype=torch.int32) +>>> init_val +tensor(51, device='xla:0') ``` -For more example and detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test case, complex test case and support for `fori_loop` and `scan` would be added after 2.3 + + +PyTorch/XLA plans to include `while_loop` support in 2.4 with a simple test case; support for `fori_loop` will be added after 2.4.
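+
+One restriction to keep in mind: the generated cond/body computations and the initial values must agree in shape and dtype, as required by `XLA::While`. A hypothetical sketch (not taken from the test suite; `good_body_fn`/`bad_body_fn` are illustrative names only):
+
+```python
+# Illustrative only; shows which body definitions satisfy the shape constraint.
+import torch
+import torch_xla
+import torch_xla.experimental.fori_loop  # registers the XLA dispatch for while_loop
+from torch._higher_order_ops.while_loop import while_loop
+import torch_xla.core.xla_model as xm
+
+device = xm.xla_device()
+
+def cond_fn(iteri, x):
+  return iteri > 0
+
+# OK: each output keeps the shape and dtype of the matching input.
+def good_body_fn(iteri, x):
+  return iteri - 1, x + 1
+
+# Problematic: the second output changes shape from (2,) to (2, 2),
+# so the generated body computation no longer matches cond/init.
+def bad_body_fn(iteri, x):
+  return iteri - 1, torch.stack([x, x])
+
+iteri = torch.tensor(3, device=device)
+x = torch.zeros(2, device=device)
+_, res = while_loop(cond_fn, good_body_fn, (iteri, x))  # expected to work
+# while_loop(cond_fn, bad_body_fn, (iteri, x))  # expected to be rejected
+```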
For `while_loop`, `body_fn` currently must be defined so that its outputs (return values) have the same shapes and dtypes as its inputs. diff --git a/test/run_tests.sh b/test/run_tests.sh index 4a298f01ee5..26d3c82303e 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -203,7 +203,7 @@ function run_xla_op_tests1 { function run_xla_op_tests2 { run_downcast_bf16 "$CDIR/test_data_type.py" run_test "$CDIR/pjrt/test_dtypes.py" - run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py" + run_test "$CDIR/test_while_loop.py" run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU } diff --git a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py b/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py deleted file mode 100644 index a76197cc736..00000000000 --- a/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import unittest -from typing import Callable, Dict, List - -import torch -import torch_xla -# We need to import the underlying implementation function to register with the dispatcher -import torch_xla.experimental.fori_loop -from torch_xla.experimental.fori_loop import fori_loop -from torch._higher_order_ops.while_loop import while_loop -import torch_xla.core.xla_model as xm -import torch_xla.core.xla_builder as xb - - -def _fake_while_loop(cond_fn, body_fn, operands): - # operands need to be more than one here - while cond_fn(*operands): - operands = body_fn(*operands) - return operands - - -def _fake_fori_loop(lower, upper, body_fun, *init_val): - (plus_value, init_val) = init_val - for i in range((upper - lower)[0]): - plus_value, init_val = body_fun(plus_value, init_val) - return init_val - - -class WhileLoopTest(unittest.TestCase): - - def test_while_loop_tpu_subtraction(self): - - device = xm.xla_device() - - def cond_fn(init, limit_value): - return limit_value[0] <= init[0] - - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - two_value = limit_value.clone() - return (torch.sub(init, one_value), two_value) - - init = torch.tensor([10], dtype=torch.int32, device=device) - limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) - expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - self.assertEqual(expected, res) - - def test_while_loop_tpu_addition(self): - - device = xm.xla_device() - - def cond_fn(init, limit_value): - return limit_value[0] >= init[0] - - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - return (torch.add(init, one_value), limit_value.clone()) - - # TODO(@manfei): init and limit_value has to be torch.tensor.
- init = torch.tensor([0], dtype=torch.int32, device=device) - limit_value = torch.tensor([10], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) - expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - self.assertEqual(expected, res) - - def test_while_loop_tpu_subtraction_nested(self): - - device = xm.xla_device() - - def cond_fn(init, limit_value): - return limit_value[0] <= init[0] - - def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - two_value = limit_value.clone() - return (torch.sub(torch.sub(init, one_value), one_value), two_value) - - init = torch.tensor([10], dtype=torch.int32, device=device) - limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) - expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) - self.assertEqual(expected, res) - - def test_fori_loop_tpu_addition(self): - - xm.mark_step() - device = xm.xla_device() - - lower = torch.tensor([2], dtype=torch.int32, device=device) - upper = torch.tensor([52], dtype=torch.int32, device=device) - plus_value = torch.tensor([1], dtype=torch.int32, device=device) - init_val = torch.tensor([1], dtype=torch.int32, device=device) - - def body_fun(*argus): - plus_value, init_val = argus - return plus_value, torch.add(plus_value, init_val) - - _, _, _, actual = fori_loop(upper, lower, body_fun, plus_value, init_val) - expected = _fake_fori_loop(lower, upper, body_fun, plus_value, init_val) - self.assertEqual(expected, actual) - - -if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) \ No newline at end of file diff --git a/test/test_while_loop.py b/test/test_while_loop.py new file mode 100644 index 00000000000..e8ea617b0f9 --- /dev/null +++ b/test/test_while_loop.py @@ -0,0 +1,116 @@ +import os +import unittest +from typing import Callable, Dict, List + +import torch +import torch_xla +# We need to import the underlying implementation function to register with the dispatcher +import torch_xla.experimental.fori_loop +from torch_xla.experimental.fori_loop import fori_loop +from torch._higher_order_ops.while_loop import while_loop +import torch_xla.core.xla_model as xm +import torch_xla.core.xla_builder as xb +import torch_xla.utils.utils as xu +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + + +def _fake_while_loop(cond_fn, body_fn, operands): + # operands need to be more than one here + while cond_fn(*operands): + operands = body_fn(*operands) + return operands + + +class WhileLoopTest(unittest.TestCase): + + def test_while_loop_addition(self): + device = xm.xla_device() + + def cond_fn(iteri, x): + return iteri > 0 + + def body_fn(iteri, x): + return iteri - 1, torch.add(x, 1) + + init_val = torch.tensor(3, dtype=torch.int32, device=device) + iteri = torch.tensor(10, device=device) + _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val)) + _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val)) + self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) + + def test_while_loop_addition_nested(self): + device = xm.xla_device() + + def cond_fn(iteri, x): + return iteri > 0 + + def body_fn(iteri, x): + return iteri - 1, torch.add(torch.add(x, 1), 1) + + init_val = torch.tensor(2, dtype=torch.int32, device=device) + iteri = torch.tensor(10, device=device) + _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val)) + 
_, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val)) + self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) + + def test_while_loop_simple_linear_inside_loop(self): + device = xm.xla_device() + torch.set_grad_enabled(False) + + class SimpleLinear(torch.nn.Module): + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, iteri, x): + + def cond_fn(iteri, x): + return iteri > 0 + + def body_fn(iteri, x): + return iteri - 1, self.linear(x) + + return while_loop(cond_fn, body_fn, (iteri, x)) + + def forward_without_while_loop_op(self, iteri, x): + while (iteri > 0): + x = self.linear(x) + iteri -= 1 + return iteri, x + + linear_model = SimpleLinear() + linear_model.to(device) + l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device) + iteri = torch.tensor(10, dtype=torch.int32, device=device) + _, res_with_loop = linear_model(iteri, l_in_0) + _, res_without_loop = linear_model.forward_without_while_loop_op( + iteri, l_in_0) + + self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop))) + + # ====== fori_loop ====== + @unittest.skip("Fori_loop is not supported now due to unstable result.") + def test_fori_loop_addition(self): + device = xm.xla_device() + + lower = torch.tensor(0, device=device) + upper = torch.tensor(50, device=device) + init_val = torch.tensor(1, dtype=torch.int32, device=device) + + def body_fun(x): + return torch.add(x, 1) + + _, res_with_loop = fori_loop(lower, upper, body_fun, (init_val)) + + # === expected === + for i in range(upper - lower): + init_val = torch.add(init_val, 1) + res_without_loop = init_val + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 6d74c40d4e3..ddd439d1c60 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -20,7 +20,7 @@ python3 test/dynamo/test_dynamo.py python3 test/spmd/test_spmd_debugging.py python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py -python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py +python3 test/test_while_loop.py python3 test/test_pallas.py python3 test/test_pallas_spmd.py python3 test/test_input_output_aliases.py diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 3fba13773b3..db3e54d8163 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -933,22 +933,9 @@ class PyLoweringContext { } // Builds a HLO graph given a set of output tensors, and add unused parameters - // needed in xlacomputation. + // needed in xlacomputation for fori_loop/while_loop. 
void BuildForiLoop(std::vector tensors, - std::vector input_arguments = {}) { - if (GetNameString() == "condctx") { - xla::XlaBuilder* local_builder = lowering_ctx.builder(); - // hard-code parameter_idx to 2 to skip existing upper/lower arguments - int64_t parameter_idx = 2; - for (at::Tensor input_argument : input_arguments) { - xla::Shape shape = - xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {1}); - xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, - "UnusedArgumentsPlaceholder"); - parameter_idx += 1; - } - } - + std::vector additional_inputs_list = {}) { // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = GetXlaTensors(tensors, /*want_all=*/true); @@ -966,6 +953,24 @@ class PyLoweringContext { torch::lazy::Output(ir_value.node.get(), ir_value.index)); lowering_ctx.AddResult(root); } + + // add dummy parameter to cond/body xlacomputation's input for xla::while + // requriement + if ((GetNameString() == "condctx") or + (GetNameString() == "bodyctx" && additional_inputs_list.size() != 0)) { + xla::XlaBuilder* local_builder = lowering_ctx.builder(); + int64_t parameter_idx = + local_builder->GetProgramShape()->parameters_size(); + int64_t additional_inputs_list_size = additional_inputs_list.size(); + for (int64_t i = parameter_idx; i < additional_inputs_list_size; i++) { + XLATensorPtr xtensor = bridge::GetXlaTensor(additional_inputs_list[i]); + xla::Shape shape = xtensor->shape().get(); + xla::XlaOp x = xla::Parameter(local_builder, parameter_idx, shape, + "UnusedArgumentsPlaceholder"); + parameter_idx += 1; + } + } + computation = ConsumeValue(lowering_ctx.BuildXla()); // wrap inputs of cond/body_computation diff --git a/torch_xla/experimental/fori_loop.py b/torch_xla/experimental/fori_loop.py index bf32a712f3e..e41709084e2 100644 --- a/torch_xla/experimental/fori_loop.py +++ b/torch_xla/experimental/fori_loop.py @@ -10,80 +10,141 @@ from torch._ops import HigherOrderOperator import torch._higher_order_ops.while_loop from torch._higher_order_ops.while_loop import while_loop_op +from torch._higher_order_ops.while_loop import while_loop as torch_while_loop +from torch._higher_order_ops.utils import _has_potential_branch_input_mutation -def fori_loop(lower, upper, user_body_func, *init_val): +def fori_loop(lower, upper, body_fun, *input_value): device = xm.xla_device() + if (upper < lower): + print("ERROR: upper should be a larger number than lower") + iteri = upper - lower - def cond_fn(upper, lower, *init_val): - return lower[0] < upper[0] + def cond_fn(iteri, *input_value): + return iteri > 0 - def body_fn(upper, lower, *init_val): - one_value_i = torch.ones(1, dtype=torch.int32, device=device) - res_list = list(user_body_func(*init_val)) - res_list.insert(0, lower) - res_list.insert(0, torch.sub(upper, one_value_i)) - return res_list + def new_body_fn(iteri, *input_value): + return iteri - 1, body_fun(*input_value) + + inputs = (iteri,) + input_value + res = _xla_while_loop_wrapper( + cond_fn, new_body_fn, inputs, (), fake_tensor=True) - res = while_loop(cond_fn, body_fn, (lower, upper, *init_val)) return res @while_loop_op.py_impl(DispatchKey.XLA) -def while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs=None): - # TODO(@manfei): PyTorch require carried_inputs to be list/tuple, PyTorch/XLA _xla_while_loop only accept *operands, *operands would tuple items again: (a, '') - # cond_fn&body_fn: callable - # carried_inputs: (Tuple of possibly nested dict/list/tuple of tensors) +def while_loop(cond_fn, body_fn, 
carried_inputs, additional_inputs=None): if additional_inputs is None: additional_inputs = tuple() - return _xla_while_loop( - cond_fn, body_fn, *carried_inputs, additional_inputs=additional_inputs) + return _xla_while_loop_wrapper(cond_fn, body_fn, carried_inputs, + additional_inputs) + + +def _xla_while_loop_wrapper(cond_fn, + body_fn, + carried_inputs, + additional_inputs=None, + fake_tensor=False): + + def new_body_fn(*carried_inputs): + res = list(body_fn(*carried_inputs)) + if additional_inputs: + res = [ + res[0], + ] + list(additional_inputs) + res[1:] + else: + res = res + return res + return _xla_while_loop(cond_fn, new_body_fn, carried_inputs, + additional_inputs, fake_tensor) -def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): - # untuple carried_inputs from while_loop - carried_inputs = carried_inputs[0] - # fake carried_inputs to split formal code + +def _xla_while_loop(cond_fn, + body_fn, + carried_inputs, + additional_inputs=None, + fake_tensor=False): + + # ====== fake_carried_inputs ====== fake_carried_inputs = [] for carried_input in carried_inputs: device = carried_input.device fake_carried_inputs.append( torch.randint(10, carried_input.size(), dtype=carried_input.dtype).to(device)) - fake_carried_inputs = tuple(fake_carried_inputs) - # trans fake_carried_inputs from list(tensor) to list(xla::op) - kwargs = {} - if type(fake_carried_inputs) is tuple: - shapes = xb.tensor_shape(fake_carried_inputs) + # ====== additional_inputs_list_cond ====== + fake_additiona_args = [] + for additional_input in additional_inputs: + device = additional_input.device + fake_additiona_args.append( + torch.randint( + 10, additional_input.size(), + dtype=additional_input.dtype).to(device)) + + # ====== inputs_list ====== + # specify body_fn_inputs/cond_fn_inputs, and add caught additional_inputs into fn_inputs + if additional_inputs or fake_tensor: + # replace inputs(carried_inputs[1:]) with fake tensors to fix missed arguments problem + body_fn_inputs = [ + carried_inputs[0], + ] + fake_carried_inputs[1:] + list(additional_inputs) + cond_fn_inputs = carried_inputs + additional_inputs else: - shapes = xb.tensor_shape((fake_carried_inputs)) - builder = xb.create_builder('test_while') - params = [] - for shape in shapes: - p = xb.mkparam(builder, len(params), shape) - params.append(p) + body_fn_inputs = carried_inputs + cond_fn_inputs = carried_inputs + + # due to `xla::While` requirement, body xlacomputation inputs/outputs, cond xlacomputation and init need to be the same shape and type; + # and carried_inputs contain (iter, values), additional_inputs contain (weights/bias) + # based on generated body xlacomputation outputs: (iter, weights/bias, values) + # we create expected order for cond/body xlacomputation generation to compare and match: (iter, weights/bias, values) + dummy_inputs_list = [ + fake_carried_inputs[0], + ] + fake_additiona_args + fake_carried_inputs[1:] + + # ====== body_fn ====== + body_result = body_fn(*body_fn_inputs) + body_ctx = torch_xla._XLAC.lowering.LoweringContext() + body_ctx.set_name_string("bodyctx") - # generate cond_fn xlacomputation - cond_result = cond_fn(*fake_carried_inputs) + # ====== body xlacomputation ====== + body_ctx.buildforiloop(list(body_result), dummy_inputs_list) + body_hlo = body_ctx.hlo() + body_computation = xb.computation_from_module_proto("bodycomputation", + body_hlo) + + # ====== cond_fn ====== + cond_result = cond_fn(*cond_fn_inputs) cond_ctx = torch_xla._XLAC.lowering.LoweringContext() 
cond_ctx.set_name_string("condctx") - cond_ctx.buildforiloop([cond_result], list(fake_carried_inputs[2:])) + + # ====== cond xlacomputation ====== + cond_ctx.buildforiloop([cond_result], dummy_inputs_list) cond_hlo = cond_ctx.hlo() cond_computation = xb.computation_from_module_proto("condcomputation", cond_hlo) - # generate body_fn xlacomputation - body_result = body_fn(*fake_carried_inputs) - body_ctx = torch_xla._XLAC.lowering.LoweringContext() - body_ctx.set_name_string("bodyctx") - body_ctx.buildforiloop(list(body_result), []) - body_hlo = body_ctx.hlo() - body_computation = xb.computation_from_module_proto("bodycomputation", - body_hlo) + # ====== xla::while ====== + iter_value = carried_inputs[0] + input_and_outputs_value = carried_inputs[1:] + total_inputs = tuple([ + iter_value, + ]) + tuple(additional_inputs) + tuple(input_and_outputs_value) + + kwargs = {} + if type(total_inputs) is tuple: + shapes = xb.tensor_shape(total_inputs) + else: + shapes = xb.tensor_shape((total_inputs)) + builder = xb.create_builder('while_loop') + params = [] + for shape in shapes: + p = xb.mkparam(builder, len(params), shape) + params.append(p) - # generate while xlacomputation input_tuple = xb.Op.tuple(tuple(params)) w = xb.mkop( 'While', (input_tuple.op,), @@ -94,6 +155,12 @@ def _xla_while_loop(cond_fn, body_fn, *carried_inputs, additional_inputs): # gain final result with generated while xlacomputation result = torch_xla._XLAC._xla_user_computation('xla::_op_test_while', - (carried_inputs), computation) + (total_inputs), computation) + + # unwrapper result without additional_inputs for original order + additional_inputs_len = len(additional_inputs) + 1 + final_res = [ + result[0], + ] + result[additional_inputs_len:] - return result \ No newline at end of file + return final_res