Introduce the scan interface (#7848)
tengyifei committed Aug 22, 2024
1 parent df1f994 commit 95f833a
Showing 4 changed files with 204 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -195,6 +195,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_scan.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
107 changes: 107 additions & 0 deletions test/test_scan.py
@@ -0,0 +1,107 @@
import sys
import unittest
import torch_xla
import torch
from torch_xla.experimental.scan import scan
from torch.utils._pytree import tree_map, tree_flatten, tree_iter

from test_utils import XlaTestCase


def _loopy_scan(fn, init, xs):
  """A simple scan implemented with for loops serving as reference
  implementation."""
  carry = init
  ys = []
  xs_len = len(next(iter(tree_iter(xs))))
  for i in range(xs_len):
    carry, y = fn(carry, tree_map(lambda x: x[i], xs))
    ys.append(y)
  ys = tree_map(lambda *x: torch.stack(x), *ys)
  return carry, ys


class ScanTest(XlaTestCase):

  def setUp(self):
    self.device = torch_xla.device()

  def compare_pytree(self, expected_pytree, actual_pytree):
    flat_expected_pytree, expected_spec = tree_flatten(expected_pytree)
    flat_actual_pytree, actual_spec = tree_flatten(actual_pytree)
    assert expected_spec == actual_spec
    super().compareResults(flat_expected_pytree, flat_actual_pytree)

  def run_test(self, step_fn, init, xs):
    # Actual output
    final_carry, ys = scan(step_fn, init, xs)
    torch_xla.sync()

    # Expected output
    expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs)
    torch_xla.sync()

    # Compare
    self.compare_pytree(expected_final_carry, final_carry)
    self.compare_pytree(expected_ys, ys)

    return final_carry, ys

  def test_scan_forward_simple(self):
    """This test uses `scan` to implement `torch.cumsum`."""

    def step_fn(carry, x):
      new_carry = carry + x
      y = new_carry
      return new_carry, y

    init = torch.tensor([0.0, 0.0], device=self.device)
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
    final_carry, ys = self.run_test(step_fn, init, xs)

    # Also ensure that our loop-based scan is correct, with manual checks
    # that replicate the step_fn.
    expected_final_carry = torch.sum(xs, dim=0) + init
    expected_ys = torch.cumsum(xs, dim=0)
    self.compare_pytree(expected_final_carry, final_carry)
    self.compare_pytree(expected_ys, ys)

  def test_scan_fn_not_callable(self):
    init = torch.tensor([1.0, 1.0], device=self.device)
    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
    with self.assertRaises(ValueError):
      scan(1000, init, xs)  # type: ignore

  def test_scan_incompatible_length(self):
    init = torch.tensor([1.0, 1.0], device=self.device)
    xs_1 = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                        device=self.device)
    xs_2 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device)
    with self.assertRaises(ValueError):
      scan(lambda a, b: (a, b), init, (xs_1, xs_2))

  def test_scan_forward_tuples(self):
    """Test scanning over the leading axis of a tuple of tensors simultaneously,
    which is a simple PyTree."""

    def step_fn(carry, x):
      carry1, carry2 = carry
      x1, x2 = x
      new_carry1 = carry1 + x1.sum()
      new_carry2 = carry2 + x2.sum()
      y1 = x1 * 2
      y2 = x2 * 2
      return (new_carry1, new_carry2), (y1, y2)

    init = (torch.tensor([0.0], device=self.device),
            torch.tensor([1.0, 2.0], device=self.device))

    xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device),
          torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]],
                       device=self.device))

    self.run_test(step_fn, init, xs)


if __name__ == '__main__':
  # `exit=False` is required here: with the default `exit=True`,
  # unittest.main() calls sys.exit() itself and the line below never runs.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -24,6 +24,7 @@ XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrappi
python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/test_scan.py
python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py
95 changes: 95 additions & 0 deletions torch_xla/experimental/scan.py
@@ -0,0 +1,95 @@
"""Module implementing the `scan` higher order operator.
Reference: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
"""

from typing import Callable, TypeVar

import torch
from torch.utils._pytree import tree_map, tree_iter

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')


def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
"""Apply a function over leading dimension of tensors while carrying along state.
This is similar to the JAX `jax.lax.scan` function found in [1].
You may use it to loop over the leading dimension of tensors efficiently. If `xs`
is a single tensor, this function is roughly equal to the following Python code:
def scan(fn, init, xs):
ys = []
carry = init
for i in len(range(xs.size(0))):
carry, y = fn(carry, xs[i])
ys.append(y)
return carry, torch.stack(ys, dim=0)
In the general case, `Carry`, `X`, and `Y` can be arbitrary PyTrees. This function
will iterate through the leading dimension of every leaf element of `xs` simultaneously,
and pass a slice of those elements to `fn` as another PyTree. This means you may
scan over multiple tensors and produce multiple output tensors at once.
Args:
fn: a Python callable that accepts two PyTrees of tensors: the carry object and the
slices of `xs` along its leading dimension. It should return two PyTrees: the carry
object and the slices of the output. The returned carry object will be passed to
the next invocation of `fn`.
init: the initial carry object passed to the first invocation of `fn`.
xs: the input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along
the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. tuple of
tensor), `fn` will get PyTrees of slices. In that case the leading dimension size
of the leaves in the PyTree must be the same.
Returns:
(carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and
`ys` is a PyTree with the same structure as `xs`, but where the leaves are formed
by stacking the leaf outputs of `fn` respectively. This means if your `fn` returns
`(carry, (y1, y2))` then this function will return
`(carry, (torch.stack(all_y1), torch.stack(all_y2)))`.
[1]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
"""

  # Ensure that `fn` is callable.
  if not callable(fn):
    raise ValueError(f"`fn` {fn} must be callable.")

  # Ensure that the leaves have the same length.
  xs_length = None
  for leaf in tree_iter(xs):
    leaf_len = len(leaf)
    if xs_length is None:
      xs_length = leaf_len
    if xs_length != leaf_len:
      raise ValueError(
          f"The leaves of the `xs` input PyTree must have the same leading dimension size. \
Got {xs_length} and {leaf_len}")

  if xs_length is None:
    raise ValueError(f"`xs` {xs} is an empty PyTree.")

  carry = init
  ys = []

  for i in range(xs_length):
    carry, y = fn(carry, tree_map(lambda x: x[i], xs))
    ys.append(y)

  # Combine the list of PyTrees into one PyTree, where the leaves are
  # stacked into a new major axis.
  ys = tree_map(lambda *x: torch.stack(x), *ys)
  return carry, ys
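
For reference, here is a minimal usage sketch of the interface introduced above (not part of this commit; the tensor values and the `step_fn` are illustrative), scanning a running-sum carry over a tuple-of-tensors PyTree:

import torch
import torch_xla
from torch_xla.experimental.scan import scan

device = torch_xla.device()

def step_fn(carry, x):
  # Accumulate the sum of both input slices; emit doubled slices as outputs.
  x1, x2 = x
  new_carry = carry + x1.sum() + x2.sum()
  return new_carry, (x1 * 2, x2 * 2)

init = torch.tensor(0.0, device=device)
xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=device),
      torch.tensor([[5.0], [6.0]], device=device))

# carry ends at 21.0; ys1 has shape (2, 2) and ys2 has shape (2, 1),
# each formed by stacking the per-step outputs along a new leading axis.
carry, (ys1, ys2) = scan(step_fn, init, xs)
torch_xla.sync()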
