diff --git a/test/run_tests.sh b/test/run_tests.sh
index b43b20b868f..16c18a204a0 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -195,6 +195,7 @@ function run_xla_op_tests1 {
 function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_while_loop.py"
+  run_test "$CDIR/test_scan.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
diff --git a/test/test_scan.py b/test/test_scan.py
new file mode 100644
index 00000000000..6926c01fb01
--- /dev/null
+++ b/test/test_scan.py
@@ -0,0 +1,107 @@
+import sys
+import unittest
+import torch_xla
+import torch
+from torch_xla.experimental.scan import scan
+from torch.utils._pytree import tree_map, tree_flatten, tree_iter
+
+from test_utils import XlaTestCase
+
+
+def _loopy_scan(fn, init, xs):
+  """A simple scan implemented with for loops, serving as a reference
+  implementation."""
+  carry = init
+  ys = []
+  xs_len = len(next(iter(tree_iter(xs))))
+  for i in range(xs_len):
+    carry, y = fn(carry, tree_map(lambda x: x[i], xs))
+    ys.append(y)
+  ys = tree_map(lambda *x: torch.stack(x), *ys)
+  return carry, ys
+
+
+class ScanTest(XlaTestCase):
+
+  def setUp(self):
+    self.device = torch_xla.device()
+
+  def compare_pytree(self, expected_pytree, actual_pytree):
+    flat_expected_pytree, expected_spec = tree_flatten(expected_pytree)
+    flat_actual_pytree, actual_spec = tree_flatten(actual_pytree)
+    assert expected_spec == actual_spec
+    super().compareResults(flat_expected_pytree, flat_actual_pytree)
+
+  def run_test(self, step_fn, init, xs):
+    # Actual output
+    final_carry, ys = scan(step_fn, init, xs)
+    torch_xla.sync()
+
+    # Expected output
+    expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs)
+    torch_xla.sync()
+
+    # Compare
+    self.compare_pytree(expected_final_carry, final_carry)
+    self.compare_pytree(expected_ys, ys)
+
+    return final_carry, ys
+
+  def test_scan_forward_simple(self):
+    """This test uses `scan` to implement `torch.cumsum`."""
+
+    def step_fn(carry, x):
+      new_carry = carry + x
+      y = new_carry
+      return new_carry, y
+
+    init = torch.tensor([0.0, 0.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
+    final_carry, ys = self.run_test(step_fn, init, xs)
+
+    # Also ensure that our loop-based scan is correct, with manual checks
+    # that replicate the step_fn.
+    expected_final_carry = torch.sum(xs, dim=0) + init
+    expected_ys = torch.cumsum(xs, dim=0)
+    self.compare_pytree(expected_final_carry, final_carry)
+    self.compare_pytree(expected_ys, ys)
+
+  def test_scan_fn_not_callable(self):
+    init = torch.tensor([1.0, 1.0], device=self.device)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
+    with self.assertRaises(ValueError):
+      scan(1000, init, xs)  # type: ignore
+
+  def test_scan_incompatible_length(self):
+    init = torch.tensor([1.0, 1.0], device=self.device)
+    xs_1 = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                        device=self.device)
+    xs_2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device)
+    with self.assertRaises(ValueError):
+      scan(lambda a, b: (a, b), init, (xs_1, xs_2))
+
+  def test_scan_forward_tuples(self):
+    """Test scanning over the leading axis of a tuple of tensors
+    simultaneously, where the tuple is a simple PyTree."""
+
+    def step_fn(carry, x):
+      carry1, carry2 = carry
+      x1, x2 = x
+      new_carry1 = carry1 + x1.sum()
+      new_carry2 = carry2 + x2.sum()
+      y1 = x1 * 2
+      y2 = x2 * 2
+      return (new_carry1, new_carry2), (y1, y2)
+
+    init = (torch.tensor([0.0], device=self.device),
+            torch.tensor([1.0, 2.0], device=self.device))
+
+    xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device),
+          torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device))
+
+    self.run_test(step_fn, init, xs)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index a49081d6bb4..f22dbb4c02b 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -24,6 +24,7 @@ XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrappi
 python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
 python3 test/test_while_loop.py
+python3 test/test_scan.py
 python3 test/test_pallas.py
 python3 test/test_pallas_spmd.py
 python3 test/test_input_output_aliases.py
diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py
new file mode 100644
index 00000000000..9008e03dbd9
--- /dev/null
+++ b/torch_xla/experimental/scan.py
@@ -0,0 +1,95 @@
+"""Module implementing the `scan` higher-order operator.
+
+Reference: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
+
+"""
+
+from typing import Callable, TypeVar
+
+import torch
+from torch.utils._pytree import tree_map, tree_iter
+
+Carry = TypeVar('Carry')
+X = TypeVar('X')
+Y = TypeVar('Y')
+
+
+def scan(
+    fn: Callable[[Carry, X], tuple[Carry, Y]],
+    init: Carry,
+    xs: X,
+) -> tuple[Carry, Y]:
+  """Apply a function over the leading dimension of tensors while carrying along state.
+
+  This is similar to the JAX `jax.lax.scan` function found in [1].
+
+  You may use it to loop over the leading dimension of tensors efficiently. If `xs`
+  is a single tensor, this function is roughly equivalent to the following Python code:
+
+    def scan(fn, init, xs):
+      ys = []
+      carry = init
+      for i in range(xs.size(0)):
+        carry, y = fn(carry, xs[i])
+        ys.append(y)
+      return carry, torch.stack(ys, dim=0)
+
+  In the general case, `Carry`, `X`, and `Y` can be arbitrary PyTrees. This function
+  will iterate through the leading dimension of every leaf element of `xs` simultaneously,
+  and pass a slice of those elements to `fn` as another PyTree. This means you may
+  scan over multiple tensors and produce multiple output tensors at once.
+
+  Args:
+
+    fn: a Python callable that accepts two PyTrees of tensors: the carry object and the
+      slices of `xs` along its leading dimension. It should return two PyTrees: the new
+      carry object and the slice of the output for this step. The returned carry object
+      will be passed to the next invocation of `fn`.
+
+    init: the initial carry object passed to the first invocation of `fn`.
+
+    xs: the input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along
+      the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. a tuple of
+      tensors), `fn` will get PyTrees of slices. In that case the leading dimension size
+      of the leaves in the PyTree must be the same.
+
+  Returns:
+
+    (carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and
+      `ys` is a PyTree with the same structure as the per-step output of `fn`, where
+      each leaf is formed by stacking the corresponding leaf outputs of `fn` across
+      iterations. This means if your `fn` returns `(carry, (y1, y2))` then this
+      function will return `(carry, (torch.stack(all_y1), torch.stack(all_y2)))`.
+
+  [1]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
+  """
+
+  # Ensure that `fn` is callable.
+  if not callable(fn):
+    raise ValueError(f"`fn` {fn} must be callable.")
+
+  # Ensure that the leaves have the same length.
+  xs_length = None
+  for leaf in tree_iter(xs):
+    leaf_len = len(leaf)
+    if xs_length is None:
+      xs_length = leaf_len
+    if xs_length != leaf_len:
+      raise ValueError(
+          "The leaves of the `xs` input PyTree must have the same leading "
+          f"dimension size. Got {xs_length} and {leaf_len}")
+
+  if xs_length is None:
+    raise ValueError(f"`xs` {xs} is an empty PyTree.")
+
+  carry = init
+  ys = []
+
+  for i in range(xs_length):
+    carry, y = fn(carry, tree_map(lambda x: x[i], xs))
+    ys.append(y)
+
+  # Combine the list of PyTrees into one PyTree, where the leaves are
+  # stacked into a new major axis.
+  ys = tree_map(lambda *x: torch.stack(x), *ys)
+  return carry, ys
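For illustration, a minimal usage sketch of the new `scan` API, mirroring `test_scan_forward_simple` above. This example is not part of the diff; it only uses calls that appear in the diff (`scan`, `torch_xla.device()`, `torch_xla.sync()`), and the expected values in the trailing comment follow from the `torch.cumsum` semantics of the step function.

import torch
import torch_xla
from torch_xla.experimental.scan import scan

device = torch_xla.device()

def step_fn(carry, x):
  # Fold each slice of `xs` into the carry and also emit the running total.
  new_carry = carry + x
  return new_carry, new_carry

init = torch.zeros(2, device=device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=device)
final_carry, ys = scan(step_fn, init, xs)
torch_xla.sync()
# final_carry == tensor([9., 12.]); ys == torch.cumsum(xs, dim=0).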