diff --git a/examples/decoder_only_model.py b/examples/decoder_only_model.py index 712423d79ad..79040e5d24d 100644 --- a/examples/decoder_only_model.py +++ b/examples/decoder_only_model.py @@ -7,16 +7,16 @@ from torch import nn -# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core. +# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core. @dataclass class DecoderOnlyConfig: hidden_size: int = 1024 num_hidden_layers: int = 2 num_attention_heads: int = 8 num_key_value_heads: int = 4 - intermediate_size = 32 * 1024 - vocab_size = 3200 - use_flash_attention = False + intermediate_size: int = 32 * 1024 + vocab_size: int = 3200 + use_flash_attention: bool = False def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: diff --git a/test/run_tests.sh b/test/run_tests.sh index 86c52b6643b..b65cbc730ec 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -197,6 +197,7 @@ function run_xla_op_tests2 { run_test "$CDIR/pjrt/test_dtypes.py" run_test "$CDIR/test_while_loop.py" run_test "$CDIR/test_scan.py" + run_test "$CDIR/test_apply_layers.py" run_test "$CDIR/test_autocast.py" run_test "$CDIR/eager/test_eager.py" run_test "$CDIR/eager/test_eager_with_xla_compile.py" diff --git a/test/test_apply_layers.py b/test/test_apply_layers.py new file mode 100644 index 00000000000..062eedb0750 --- /dev/null +++ b/test/test_apply_layers.py @@ -0,0 +1,149 @@ +import sys +import os +example_folder = os.path.dirname(os.path.dirname(os.path.abspath( + sys.argv[0]))) + "/examples" +sys.path.append(example_folder) +from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore + +import sys +import unittest +from typing import Iterable + +import torch + +import torch_xla +from torch_xla.experimental.apply_layers import apply_layers + +from test_utils import XlaTestCase # type:ignore + + +class ApplyLayersTest(XlaTestCase): + + def setUp(self): + super().setUp() + + self.device = torch_xla.device() + + def test_empty_layers(self): + layers = [] + input_data = torch.randn(64).to(self.device) + torch_xla.sync() + output = apply_layers(layers, input_data.clone()) + super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.01) + + def test_linear_layers(self): + # We want to apply these layers sequentially + import torch.nn as nn + layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)] + input_data = torch.randn(64).to(self.device) + + from copy import deepcopy + scan_layers = deepcopy(layers) + loop_layers = deepcopy(layers) + torch_xla.sync() + + output = apply_layers(scan_layers, input_data.clone()) + output.sum().backward() + + # Test that the result is the same as for loop. + loop_output = input_data.clone() + from copy import deepcopy + for layer in loop_layers: + loop_output = layer(loop_output) + torch_xla.sync() + + super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.01) + + loop_output.sum().backward() + torch_xla.sync() + + # Test that the gradients are the same too. + for layer_scan, layer_loop in zip(scan_layers, loop_layers): + super().compareResults( + layer_scan.weight.grad, + layer_loop.weight.grad, + abs_err=0.0001, + rel_err=0.01) + super().compareResults( + layer_scan.bias.grad, + layer_loop.bias.grad, + abs_err=0.0001, + rel_err=0.01) + + def test_decoder_model(self): + # Define a decoder model that composes the decoder model in the example, + # but adds the ability to run the layers with the `scan` operator. 
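+    # `forward` runs the plain Python for-loop over the decoder layers, while
+    # `forward_scan` routes the same layers through `apply_layers`, which
+    # traces the layer computation once and reuses it for every layer.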
+ class DecoderOnlyModelWithScan(torch.nn.Module): + + def __init__(self, **kwargs): + super(DecoderOnlyModelWithScan, self).__init__() + self.decoder = DecoderOnlyModel(**kwargs) + + @property + def layers(self) -> Iterable[torch.nn.Module]: + return self.decoder.layers + + def forward( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.decoder.forward(input_ids) + + def forward_scan( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + inputs_embeds = self.decoder.embed_tokens(input_ids) + # embed positions + assert isinstance(inputs_embeds, torch.Tensor) + # decoder layers + hidden_states = apply_layers(self.decoder.layers, inputs_embeds) + hidden_states = self.decoder.norm(hidden_states) + # [B, S, H] -> [B, S, V] + return self.decoder.output(hidden_states) + + # Make it smaller for fast model run and comparisons. + config = DecoderOnlyConfig( + hidden_size=128, intermediate_size=8 * 128, vocab_size=256) + model = DecoderOnlyModelWithScan(config=config).to(self.device) + batch_size = 2 + sequence_length = 8 + + # Generate random input_ids within the range of the vocabulary size + input_ids = torch.randint(0, config.vocab_size, + (batch_size, sequence_length)).to(self.device) + + from copy import deepcopy + loop_model = deepcopy(model) + scan_model = deepcopy(model) + torch_xla.sync() + + # Run the loop-based model. + loop_output = loop_model(input_ids.clone()) + loop_output.sum().backward() + torch_xla.sync() + + # Run again, this time using `scan` + scan_output = scan_model.forward_scan(input_ids.clone()) + scan_output.sum().backward() + torch_xla.sync() + + # Compare results + super().compareResults(scan_output, loop_output, abs_err=0.05, rel_err=0.01) + + # Check gradients + for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers): + for (name, + param_scan), (name2, + param_loop) in zip(layer_scan.named_parameters(), + layer_loop.named_parameters()): + assert name == name2 + if param_scan.grad is not None or param_loop.grad is not None: + super().compareResults( + param_scan.grad, param_loop.grad, abs_err=0.1, rel_err=0.05) + print(f"Pass: {name} {param_scan.shape}") + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_operations.py b/test/test_operations.py index 1af928e6a47..3059b7e221c 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -21,6 +21,7 @@ import itertools import math from numbers import Number +from functools import reduce import numpy import random import re @@ -2597,6 +2598,29 @@ def test_api(self): mapping = ctx.parameter_id_tensor_mapping() self.assertEqual(len(mapping), 2) + def test_get_parameters_scalar(self): + """Scalar tensors parameters may be shared in the HLO graph if their + numerical values are equal. `parameter_id_tensor_mapping` needs to handle + that appropriately. + """ + + device = xm.xla_device() + tensors = [] + for i in range(10): + # Add three copies of the same value. 
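+      # Equal scalar values may be de-duplicated into a single HLO parameter,
+      # so the number of parameters can be smaller than the number of tensors;
+      # the assertion below compares against the parameter count reported in
+      # the HLO proto rather than a hard-coded 30.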
+ tensors.append(torch.tensor(i, device=device)) + tensors.append(torch.tensor(i, device=device)) + tensors.append(torch.tensor(i, device=device)) + result = reduce(lambda a, b: a + b, tensors) + ctx = torch_xla._XLAC.lowering.LoweringContext() + ctx.build([result]) + mapping = ctx.parameter_id_tensor_mapping() + + import json + hlo_json = json.loads(ctx.hlo_json()) + num_parameters = len(hlo_json["hostProgramShape"]["parameters"]) + self.assertEqual(len(mapping), num_parameters) + class TestGeneric(test_utils.XlaTestCase): diff --git a/test/test_scan.py b/test/test_scan.py index 6926c01fb01..c035f940960 100644 --- a/test/test_scan.py +++ b/test/test_scan.py @@ -1,11 +1,14 @@ import sys import unittest -import torch_xla +from functools import reduce + import torch +from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree + +import torch_xla from torch_xla.experimental.scan import scan -from torch.utils._pytree import tree_map, tree_flatten, tree_iter -from test_utils import XlaTestCase +from test_utils import XlaTestCase # type:ignore def _loopy_scan(fn, init, xs): @@ -24,6 +27,8 @@ def _loopy_scan(fn, init, xs): class ScanTest(XlaTestCase): def setUp(self): + super().setUp() + self.device = torch_xla.device() def compare_pytree(self, expected_pytree, actual_pytree): @@ -32,22 +37,43 @@ def compare_pytree(self, expected_pytree, actual_pytree): assert expected_spec == actual_spec super().compareResults(flat_expected_pytree, flat_actual_pytree) - def run_test(self, step_fn, init, xs): + def run_test(self, fn, init: PyTree, xs: PyTree): + """Compares the result of scanning with `fn` with our optimized HLO implementation + against a for loop implementation. Checks both output values and gradients. + """ # Actual output - final_carry, ys = scan(step_fn, init, xs) + init_scan = tree_map(lambda v: v.detach().requires_grad_(), init) + xs_scan = tree_map(lambda v: v.detach().requires_grad_(), xs) + final_carry, ys = scan(fn, init_scan, xs_scan) + # Add up all leaves in `ys` and `backward()` once. + reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(ys)), + torch.tensor(0.0)).backward() torch_xla.sync() # Expected output - expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs) + init_loop = tree_map(lambda v: v.detach().requires_grad_(), init) + xs_loop = tree_map(lambda v: v.detach().requires_grad_(), xs) + expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop) + # Add up all leaves in `ys` and `backward()` once. 
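+    # Reducing every leaf to a single scalar means one `backward()` call
+    # populates gradients for all of `init_loop` and `xs_loop`, mirroring the
+    # reduction applied to the scanned outputs above.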
+ reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(expected_ys)), + torch.tensor(0.0)).backward() torch_xla.sync() - # Compare + # Compare values self.compare_pytree(expected_final_carry, final_carry) self.compare_pytree(expected_ys, ys) + # Compare gradients + self.compare_pytree( + tree_map(lambda v: v.grad, init_scan), + tree_map(lambda v: v.grad, init_loop)) + self.compare_pytree( + tree_map(lambda v: v.grad, xs_scan), tree_map(lambda v: v.grad, + xs_loop)) + return final_carry, ys - def test_scan_forward_simple(self): + def test_scan_simple(self): """This test uses `scan` to implement `torch.cumsum`.""" def step_fn(carry, x): @@ -55,8 +81,10 @@ def step_fn(carry, x): y = new_carry return new_carry, y - init = torch.tensor([0.0, 0.0], device=self.device) - xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device) + init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device) + xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + requires_grad=True, + device=self.device) final_carry, ys = self.run_test(step_fn, init, xs) # Also ensure that our loop-based scan is correct, with manual checks @@ -80,26 +108,30 @@ def test_scan_incompatible_length(self): with self.assertRaises(ValueError): scan(lambda a, b: (a, b), init, (xs_1, xs_2)) - def test_scan_forward_tuples(self): + def test_scan_tuples(self): """Test scanning over the leading axis of a tuple of tensors simultaneously, which is a simple PyTree.""" - def step_fn(carry, x): + def fn(carry, x): carry1, carry2 = carry x1, x2 = x new_carry1 = carry1 + x1.sum() new_carry2 = carry2 + x2.sum() - y1 = x1 * 2 - y2 = x2 * 2 + y1 = x1 * 2 + torch.sum(new_carry1) + y2 = x2 * 2 + torch.sum(new_carry2) return (new_carry1, new_carry2), (y1, y2) - init = (torch.tensor([0.0], device=self.device), - torch.tensor([1.0, 2.0], device=self.device)) + init = (torch.tensor([0.0], requires_grad=True, device=self.device), + torch.tensor([1.0, 2.0], requires_grad=True, device=self.device)) - xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device), - torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device)) + xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], + requires_grad=True, + device=self.device), + torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], + requires_grad=True, + device=self.device)) - self.run_test(step_fn, init, xs) + self.run_test(fn, init, xs) if __name__ == '__main__': diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index d7a8ba1c2a6..32a5e980889 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -25,6 +25,7 @@ python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py python3 test/test_while_loop.py python3 test/test_scan.py +python3 test/test_apply_layers.py python3 test/test_pallas.py python3 test/test_pallas_spmd.py python3 test/test_input_output_aliases.py diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 4689787bc56..ab31be0ad65 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1064,7 +1064,6 @@ class PyLoweringContext { // etc.) 
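+  // Returns a map from each parameter ID in the lowered computation to a
+  // tensor materialized from the corresponding device data.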
std::unordered_map GetParameterIdTensorMapping() { // Find parameters in the lowering - const std::vector& param_ids = lowering_ctx.GetParameterSequence(); const std::vector& device_data = lowering_ctx.GetParametersData(); @@ -1081,7 +1080,9 @@ class PyLoweringContext { at::ScalarType dtype = MaybeUpcastToHostTorchType(literal.shape().element_type()); at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype); - results[param_ids[i]] = input; + std::optional param_id = lowering_ctx.GetParameterId(device_data[i]); + XLA_CHECK(param_id.has_value()); + results[param_id.value()] = input; } return results; } @@ -1104,12 +1105,13 @@ class PyLoweringContext { torch::lazy::BackendData::Handle handle = data->GetHandle(); // Linearly search parameters and compare opaque handles - const std::vector& param_ids = lowering_ctx.GetParameterSequence(); const std::vector& device_data = lowering_ctx.GetParametersData(); for (int i = 0; i < device_data.size(); ++i) { if (device_data[i]->GetHandle() == handle) { - return param_ids[i]; + std::optional param_id = lowering_ctx.GetParameterId(device_data[i]); + XLA_CHECK(param_id.has_value()); + return param_id.value(); } } return -1; diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index c104be7c438..c2db9b36309 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -136,6 +137,16 @@ xla::XlaOp LoweringContext::GetParameter( return it->second.param; } +std::optional LoweringContext::GetParameterId( + const std::shared_ptr& data) const { + torch::lazy::BackendData::Handle handle = data->GetHandle(); + auto it = parameters_map_.find(handle); + if (it == parameters_map_.end()) { + return std::nullopt; + } + return it->second.index; +} + const std::vector& LoweringContext::GetParametersData() const { return parameters_; @@ -195,13 +206,14 @@ void LoweringContext::AssignOutputOp(const torch::lazy::Output& output, xla::XlaOp LoweringContext::GetOutputOp(const torch::lazy::Output& output) { auto it = emitted_outputs_.find(output); + if (it == emitted_outputs_.end()) { auto post_order = torch::lazy::Util::ComputePostOrder(output.node, &emit_status_); for (auto node : post_order) { LowerNode(node); } - // At this point the outpout better be present, otherwise there is an issue + // At this point the output better be present, otherwise there is an issue // with the lowering code. it = emitted_outputs_.find(output); XLA_CHECK(it != emitted_outputs_.end()) @@ -216,6 +228,7 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { HloMetadataSetter meta_setter(this, node); const XlaNode* casted = dynamic_cast(node); + result_ops = casted->Lower(this); if (!casted->dynamic_dims().empty()) { xla::internal::XlaBuilderFriend builder_friend; diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index e645f959af0..3a36695e1c0 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -52,6 +53,11 @@ class LoweringContext : public torch::lazy::LoweringContext { const std::shared_ptr& data, const std::unordered_set& dynamic_dims = {}); + // If a parameter associated with data has already been declared, returns its + // ID. Otherwise, returns `std::nullopt`. 
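+  // The lookup is keyed by the backend data's opaque handle and never
+  // registers a new parameter.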
+ std::optional GetParameterId( + const std::shared_ptr& data) const; + // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. const std::vector& GetParametersData() const; diff --git a/torch_xla/experimental/apply_layers.py b/torch_xla/experimental/apply_layers.py new file mode 100644 index 00000000000..8152d493d1d --- /dev/null +++ b/torch_xla/experimental/apply_layers.py @@ -0,0 +1,72 @@ +from typing import Iterable + +import torch +import torch.nn as nn +from torch.utils._pytree import tree_map + +from torch_xla.experimental.scan import scan + + +def apply_layers(layers: Iterable[torch.nn.Module], input_data): + """Applies each layer in `layers` to `input_data` sequentially. + + `input_data` is provided as input to the first layer in `layers`. The output of one + layer is provided as input to next layer. This function is equivalent to + + sequential = torch.nn.Sequential(layers) + sequential(input_data) + + This function can be faster to compile since it reuses the XLA computation of the + first layer to perform the computation of all other layers. + """ + # Handle empty layers case. + try: + next(iter(layers)) + except StopIteration: + return input_data + + # Extract and stack the parameters into a pytree. + params = [_extract_weights_dict(layer) for layer in layers] + stacked_params = tree_map(lambda *tensors: torch.stack(tensors, dim=0), + *params) + + # Use the first layer as the example/template layer. + from copy import deepcopy + example_layer = deepcopy(next(iter(layers))) + + # Hollow out the weights and biases in the example layer. + example_layer = example_layer.to_empty(device=None) + + # Define the function to apply at each step + def one_layer(carry, params): + # Apply the current layer's weights and biases to the example layer, + # then run the resulting layer. + _apply_weights_dict(example_layer, params) + # TODO(yifeit): it should be possible to return `None` as opposed to + # `example_layer(carry) * 0`, for additional clarity. There is no extra + # computation since we discard `ys` right after. + return example_layer(carry), example_layer(carry) * 0 + + final_carry, _ = scan(one_layer, input_data, stacked_params) + + return final_carry + + +def _extract_weights_dict(module: nn.Module): + """ + Extracts the parameters (weights and biases) from a PyTorch module and + stores them in a dictionary. + """ + weights_dict = { + name: param.clone() for name, param in module.named_parameters() + } + return weights_dict + + +def _apply_weights_dict(module: nn.Module, weights_dict): + """ + Re-applies the weights and biases from the dictionary back to the PyTorch module. + """ + for name, param in module.named_parameters(): + if name in weights_dict: + torch.utils.swap_tensors(param, weights_dict[name].clone()) diff --git a/torch_xla/experimental/pytreeify.py b/torch_xla/experimental/pytreeify.py new file mode 100644 index 00000000000..9fb0d282526 --- /dev/null +++ b/torch_xla/experimental/pytreeify.py @@ -0,0 +1,50 @@ +import torch.utils._pytree as pytree +from torch.autograd import Function + + +# Taken from https://github.com/pytorch/pytorch/issues/96337 +# +# The main purpose is to support autograd in the `scan` operator, which takes in +# PyTrees and outputs PyTrees. Builtin PyTorch autograd ignores tensors in +# non-trivial PyTrees such as dictionaries of tensors. This decorator adds +# arbitrary PyTree support by flattening the PyTree before handing to PyTorch and +# unflattening on the way back. 
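+#
+# A minimal sketch (hypothetical `Double` Function) of the resulting calling
+# convention: `Double.apply({"a": x})` returns a dict with the same structure,
+# and `backward` receives/returns PyTrees matching the forward output/input:
+#
+#   @pytreeify
+#   class Double(Function):
+#
+#     @staticmethod
+#     def forward(ctx, tree):
+#       return {k: v * 2 for k, v in tree.items()}
+#
+#     @staticmethod
+#     def backward(ctx, grad_tree):
+#       # Grads come back as a tuple mirroring the forward arguments.
+#       return ({k: v * 2 for k, v in grad_tree.items()},)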
+def pytreeify(cls): + assert issubclass(cls, Function) + + orig_fw = cls.forward + orig_bw = cls.backward + orig_apply = cls.apply + + def new_apply(*inp): + flat_inp, struct = pytree.tree_flatten(inp) + out_struct_holder = [] + flat_out = orig_apply(struct, out_struct_holder, *flat_inp) + assert flat_out is not None + assert len(out_struct_holder) == 1 + return pytree.tree_unflatten(flat_out, out_struct_holder[0]) + + def new_forward(ctx, struct, out_struct_holder, *flat_inp): + inp = pytree.tree_unflatten(flat_inp, struct) + out = orig_fw(ctx, *inp) + flat_out, out_struct = pytree.tree_flatten(out) + ctx._inp_struct = struct + ctx._out_struct = out_struct + out_struct_holder.append(out_struct) + return tuple(flat_out) + + def new_backward(ctx, *flat_grad_outputs): + grad_outputs = pytree.tree_unflatten(flat_grad_outputs, ctx._out_struct) + if not isinstance(grad_outputs, tuple): + grad_outputs = (grad_outputs,) + grad_inputs = orig_bw(ctx, *grad_outputs) + flat_grad_inputs, grad_inputs_struct = pytree.tree_flatten(grad_inputs) + if grad_inputs_struct != ctx._inp_struct: + raise RuntimeError("The backward generated an arg structure that doesn't " + "match the forward's input.") + return (None, None) + tuple(flat_grad_inputs) + + cls.apply = new_apply + cls.forward = new_forward + cls.backward = new_backward + return cls diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 9008e03dbd9..a88bce4532a 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -4,10 +4,16 @@ """ -from typing import Callable, TypeVar +import itertools +from typing import Callable, Dict, Sequence, TypeVar, Tuple, List import torch -from torch.utils._pytree import tree_map, tree_iter +import torch.autograd +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten, tree_iter, PyTree + +import torch_xla +import torch_xla.core.xla_builder as xb +from torch_xla.experimental.pytreeify import pytreeify Carry = TypeVar('Carry') X = TypeVar('X') @@ -82,14 +88,335 @@ def scan(fn, init, xs): if xs_length is None: raise ValueError(f"`xs` {xs} is an empty PyTree.") - carry = init - ys = [] + carry, carry_history, ys = Scan.apply(fn, init, xs) # type: ignore + return carry, ys + - for i in range(xs_length): - carry, y = fn(carry, tree_map(lambda x: x[i], xs)) - ys.append(y) +def dynamic_update_slice(ys: xb.Op, y: xb.Op, idx: xb.Op) -> xb.Op: + # See https://openxla.org/xla/operation_semantics#dynamicupdateslice. + y = y.broadcast([1]) + indices = [idx] + for _ in range(ys.shape().rank - 1): + indices.append(idx.zeros_like()) + return ys.dynamic_update_slice(y, indices) - # Combine the list of PyTrees into one PyTree, where the leaves are - # stacked into a new major axis. 
- ys = tree_map(lambda *x: torch.stack(x), *ys) - return carry, ys + +def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: + indices = [idx] + for _ in range(xs.shape().rank - 1): + indices.append(idx.zeros_like()) + slice_shape = list(xs.shape().sizes) + slice_shape[0] = 1 + sliced = xs.dynamic_slice(indices, slice_shape) + shape = list(xs.shape().sizes) + shape = shape[1:] + return sliced.reshape(shape) + + +class Builder: + + def __init__(self, name: str): + self._builder = xb.create_builder(name) + self._params = [] + self._param_tensors = [] + + def add_param(self, val: torch.Tensor): + idx = len(self._params) + param = xb.mkparam(self._builder, idx, xb.tensor_shape(val)) + self._params.append(param) + self._param_tensors.append(val) + return idx + + def params(self) -> Tuple[xb.Op, ...]: + return tuple(self._params) + + def param_tensors(self) -> Tuple[torch.Tensor, ...]: + return tuple(self._param_tensors) + + def num_params(self) -> int: + return len(self._params) + + +def _scan_impl(fn, init, xs): + """Forward logic of scan without gradient tracking. + + See the `Scan` class which implements an autograd `Function` and builds + autograd support on top of `_scan_impl`. + """ + + flat_init, carry_spec = tree_flatten(init) + flat_xs, xs_spec = tree_flatten(xs) + + # Because `flat_fn` returns a concatenated flattened carry and y list, + # we need to know how many elements out of that list is the carry. + flat_carry_len = len(flat_init) + flat_xs_len = len(flat_xs) + + # `fn` operates on PyTrees and returns PyTrees. However, XLA only understands + # (lists or tuples of) tensors. So we will craft a `flat_fn` that takes in + # flattened PyTrees, internally recreates the desired tree structure, then calls `fn`. + def flat_fn(*seq: torch.Tensor) -> Tuple[PyTree, PyTree]: + carry = seq[:flat_carry_len] + x = seq[flat_carry_len:] + carry_pytree = tree_unflatten(carry, carry_spec) + x_pytree = tree_unflatten(x, xs_spec) + return fn(carry_pytree, x_pytree) + + # Abstractly trace and lower `fn`. + # Later we will include `fn_computation` within the while loop body. + def make_fake_tensor(v: torch.Tensor) -> torch.Tensor: + # TODO(yifeit): there are some problems in PyTorch/XLA, or I'm missing something, because + # + # torch.empty(v.size(), dtype=v.dtype, requires_grad=v.requires_grad).to(device) + # + # results in a tensor without gradient tracking, and + # + # torch.empty(v.size(), dtype=v.dtype, device=v.device, requires_grad=v.requires_grad) + # + # results in incorrect calculation, unless they are `print()`-ed. + # But all three should be equivalent. + t = torch.empty(v.size(), dtype=v.dtype).to(device) + t.requires_grad_(v.requires_grad) + return t + + device = torch_xla.device() + fake_carry_pytree = tree_map(make_fake_tensor, init) + fake_x_pytree = tree_map(lambda v: make_fake_tensor(v[0]), xs) + fake_carry, _ = tree_flatten(fake_carry_pytree) + fake_x, _ = tree_flatten(fake_x_pytree) + fn_output_carry_pytree, fn_output_y_pytree = flat_fn(*(fake_carry + fake_x)) + + # Later we'll use `fn_output_carry_spec` etc to turn flattened outputs back to a PyTree. 
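+  # The carry returned by `fn` must have the same tree structure as `init`;
+  # the assert below enforces that invariant before `fn` is lowered to HLO.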
+ fn_output_carry, fn_output_carry_spec = tree_flatten(fn_output_carry_pytree) + assert fn_output_carry_spec == carry_spec + fn_output_y, fn_output_y_spec = tree_flatten(fn_output_y_pytree) + flat_y_len = len(fn_output_y) + fn_outputs = fn_output_carry + fn_output_y + + fn_ctx = torch_xla._XLAC.lowering.LoweringContext() + fn_ctx.set_name_string("my_ctx") + fn_ctx.build(list(fn_outputs)) + fn_hlo = fn_ctx.hlo() + fn_computation = xb.computation_from_module_proto("my_fn_computation", fn_hlo) + + builder = Builder('scan') + + # Figure out the shape of `ys` from the abstract tracing. + fn_carry_out = fn_outputs[:flat_carry_len] + fn_y_out = fn_outputs[flat_carry_len:] + assert flat_carry_len + flat_y_len == len(fn_outputs) + fn_carry_shapes = [v.shape for v in fn_carry_out] + fn_y_shapes = [v.shape for v in fn_y_out] + for fn_carry_shape, init_leaf in zip(fn_carry_shapes, flat_init): + assert fn_carry_shape == init_leaf.shape, f"`fn` must keep the `carry` shape unchanged. \ + Got {fn_carry_shape} but expected {init_leaf.shape}" + + # Since we are threading four PyTrees through the body_fn: + # - carry: the scan state + # - xs: the flattened input pytree + # - fn_carry_history: history of that state + # - ys: the flattened output of fn + # + # We need to concatenate all three into one big list prior to + # entering `body_fn` and `cond_fn`, and split them back to three + # objects which is easier to work with after that. This pair of + # functions is for that purpose. + T = TypeVar('T') + + def pack(carry: Sequence[T], xs: Sequence[T], fn_carry_history: Sequence[T], + ys: Sequence[T]) -> Tuple[T, ...]: + return tuple(carry) + tuple(xs) + tuple(fn_carry_history) + tuple(ys) + + def unpack(seq: Sequence[T]) -> Tuple[List[T], List[T], List[T], List[T]]: + seq = list(seq) + carry = seq[:flat_carry_len] + xs = seq[flat_carry_len:flat_carry_len + flat_xs_len] + fn_carry_history = seq[flat_carry_len + flat_xs_len:flat_carry_len * 2 + + flat_xs_len] + ys = seq[flat_carry_len * 2 + flat_xs_len:] + return carry, xs, fn_carry_history, ys + + xs_len = next(iter(tree_iter(xs))).size(0) + num_iters = torch.tensor(xs_len, device=device) + ys = [ + torch.zeros((xs_len, *fn_y_shape), device=device) + for fn_y_shape in fn_y_shapes + ] + fn_carry_history = [ + torch.zeros((xs_len, *fn_carry_shape), device=device) + for fn_carry_shape in fn_carry_shapes + ] + loop_tensors: Tuple[torch.Tensor, ...] = (num_iters,) + pack( + flat_init, flat_xs, fn_carry_history, ys) + for val in loop_tensors: + builder.add_param(val) + + # If there are additional device data tensors referenced by the computation that + # are not input or carry, we need to provide those tensors when calling + # `fn_computation`. As a result, we need to determine what are those tensors and they + # need to be provided as additional inputs to `cond_fn` and `body_fn`. + + # Add additional inputs as params as well. 
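+  # Any parameter of the lowered `fn` that does not correspond to one of the
+  # fake carry/x tensors (e.g. weights captured by `fn`'s closure) is appended
+  # to the while-loop parameters here and threaded through `cond_fn`/`body_fn`
+  # unchanged.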
+ mapping: Dict[int, torch.Tensor] = fn_ctx.parameter_id_tensor_mapping() + param_id_to_additional_tensors_param_id: Dict[int, int] = {} + num_params = len(mapping) + for v in itertools.chain(fake_carry, fake_x): + param_id = fn_ctx.tensor_parameter_id(v) + if param_id != -1: + del mapping[param_id] + for param_id in range(num_params): + if param_id in mapping: + idx = builder.add_param(mapping[param_id].to(torch_xla.device())) + param_id_to_additional_tensors_param_id[param_id] = idx + num_additional_inputs = len(mapping) + + def skip_additional_inputs(fn): + + def wrapper(*args): + first_args = args[:builder.num_params() - num_additional_inputs] + return fn(*first_args) + + return wrapper + + def pass_through_additional_inputs(fn): + + def wrapper(*args): + first_args = args[:builder.num_params() - num_additional_inputs] + additional_inputs = args[builder.num_params() - num_additional_inputs:] + res = fn(*first_args, additional_inputs=additional_inputs) + assert isinstance(res, tuple) + return xb.Op.tuple(res + additional_inputs) + + return wrapper + + # For each tensor, we need to know its parameter ID. + # Then we should order the tensors in increasing parameter ID order when passing them + # to `xb.Op.call`. For each (ID, tensor) in `mapping`: + # - Check if the tensor is a fake tensor we just created + # - If yes, find the position in the fake tensor list. Index into the input ops. + # - If no, find the op in the additional input list. + def call_fn_computation(carry: List[xb.Op], x: List[xb.Op], + additional_inputs: Tuple[xb.Op, ...]) -> xb.Op: + param_id_to_fake_tensors_id: Dict[int, int] = {} + for i, v in enumerate(itertools.chain(fake_carry, fake_x)): + param_id = fn_ctx.tensor_parameter_id(v) + if param_id != -1: + param_id_to_fake_tensors_id[param_id] = i + + all_inputs = carry + x + all_inputs_reordered = [] + mapping: Dict[int, torch.Tensor] = fn_ctx.parameter_id_tensor_mapping() + for i in range(len(mapping)): + if i in param_id_to_fake_tensors_id: + op = all_inputs[param_id_to_fake_tensors_id[i]] + all_inputs_reordered.append(op) + else: + op = additional_inputs[param_id_to_additional_tensors_param_id[i] - + len(loop_tensors)] + all_inputs_reordered.append(op) + return xb.Op.call(fn_computation, all_inputs_reordered) + + @skip_additional_inputs + def cond_fn(num_iters: xb.Op, *args: xb.Op): + return num_iters > xb.Op.scalar(num_iters.builder(), 0, dtype=xb.Type.S64) + + @pass_through_additional_inputs + def body_fn(num_iters: xb.Op, *args: xb.Op, additional_inputs: Tuple[xb.Op, + ...]): + carry, xs, fn_carry_history, ys = unpack(args) + xs_len_op = xb.Op.scalar(num_iters.builder(), xs_len, dtype=xb.Type.S64) + one = xb.Op.scalar(num_iters.builder(), 1, dtype=xb.Type.S64) + idx = xs_len_op - num_iters + x = [dynamic_slice(v, idx) for v in xs] + for i in range(len(carry)): + fn_carry_history[i] = dynamic_update_slice(fn_carry_history[i], carry[i], + idx) + result = call_fn_computation(carry, x, additional_inputs) + for i in range(flat_carry_len): + carry[i] = result.get_tuple_element(i) + for i in range(flat_y_len): + y = result.get_tuple_element(i + flat_carry_len) + ys[i] = dynamic_update_slice(ys[i], y, idx) + return (num_iters - one,) + pack(carry, xs, fn_carry_history, ys) + + res = xb.Op.mkwhile(builder.params(), cond_fn, body_fn) + computation = res.build('scan') + + outputs = torch_xla._XLAC._xla_user_computation('xla::scan', + builder.param_tensors(), + computation) + # skip the last num_additional_inputs + outputs = outputs[:len(outputs) - num_additional_inputs] + 
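+  # What remains mirrors the packed loop tensors: `num_iters` first, followed
+  # by the flattened carry, xs, fn_carry_history and ys.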
# `1:` to skip `num_iters` + carry, xs, fn_carry_history, ys = unpack(outputs[1:]) + + # Unflatten tensors back to PyTrees + return tree_unflatten(carry, carry_spec), tree_unflatten( + fn_carry_history, carry_spec), tree_unflatten(ys, fn_output_y_spec) + + +@pytreeify +class Scan(torch.autograd.Function): + + @staticmethod + def forward(ctx, fn, init, xs): + # Forward pass, save inputs for backward + ctx._fn = fn + with torch._C._AutoDispatchBelowAutograd(): + carry, carry_history, ys = _scan_impl(fn, init, xs) + flat_carry_history, carry_spec = tree_flatten(carry_history) + flat_xs, xs_spec = tree_flatten(xs) + ctx.save_for_backward(*flat_carry_history, *flat_xs) + ctx._flat_carry_len = len(flat_carry_history) + ctx._carry_spec = carry_spec + ctx._xs_spec = xs_spec + return carry, carry_history, ys + + @staticmethod + def backward(ctx, grad_carry, grad_carry_history, grad_ys): + fn = ctx._fn + flat_carry_len = ctx._flat_carry_len + carry_spec = ctx._carry_spec + xs_spec = ctx._xs_spec + tensors_list = ctx.saved_tensors + carry_history = tree_unflatten(tensors_list[:flat_carry_len], carry_spec) + xs = tree_unflatten(tensors_list[flat_carry_len:], xs_spec) + + def detach_tensor(inp: torch.Tensor) -> torch.Tensor: + x = inp.clone().detach() + x.requires_grad = inp.requires_grad + return x + + def step_fn(grad_carry, pytree: Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]): + grad_y, carry, x = pytree + # Compute the backward of a single scan iteration + detached_inputs = tree_map(detach_tensor, (carry, x)) + with torch.enable_grad(): + # `outputs` is a tuple of two PyTrees, the carry and y. + # In order to use `autograd.backward`, we need to flatten all + # PyTrees. The flattened `grad_carry` should match up with `carry` + # and similar for `y`. + outputs = fn(*detached_inputs) + output_carry, output_y = outputs + flat_output_carry, output_carry_spec = tree_flatten(output_carry) + flat_output_y, output_y_spec = tree_flatten(output_y) + flat_grad_carry, grad_carry_spec = tree_flatten(grad_carry) + flat_grad_y, grad_y_spec = tree_flatten(grad_y) + assert output_carry_spec == grad_carry_spec + assert output_y_spec == grad_y_spec + torch.autograd.backward(flat_output_carry + flat_output_y, + flat_grad_carry + flat_grad_y) + grad_carry, grad_x = tree_map(lambda v: v.grad, detached_inputs) + return grad_carry, grad_x + + # Reverse loop to accumulate gradients + grad_init = tree_map(lambda v: v.clone(), grad_carry) + carry_history = tree_map(lambda v: v.flip(0).requires_grad_(True), + carry_history) + xs = tree_map(lambda v: v.flip(0).requires_grad_(True), xs) + grad_ys = tree_map(lambda v: v.flip(0).requires_grad_(True), grad_ys) + + grad_init, _, grad_xs = _scan_impl(step_fn, grad_init, + (grad_ys, carry_history, xs)) + return None, grad_init, tree_map(lambda v: v.flip(0), grad_xs)
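+
+
+# Usage sketch (mirrors `test_scan_simple` in test/test_scan.py): a cumulative
+# sum implemented with `scan`:
+#
+#   def step(carry, x):
+#     new_carry = carry + x
+#     return new_carry, new_carry
+#
+#   init = torch.zeros(2, device=torch_xla.device(), requires_grad=True)
+#   xs = torch.randn(3, 2, device=torch_xla.device(), requires_grad=True)
+#   final_carry, ys = scan(step, init, xs)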