Skip to content

Commit

Permalink
scan and apply_layers
Browse files Browse the repository at this point in the history
Add the lowering of scan to HLO While op.

Introduce apply_layers which can sequentially apply a bunch of layers
using scan underneath.

Beef up unit tests including linear layers and decoders.

add regression test for parameter_id_tensor_mapping

add test_apply_layers.py to test shell scripts

correctly import decoder model from examples
  • Loading branch information
tengyifei committed Sep 3, 2024
1 parent 989ac69 commit 67cff9b
Show file tree
Hide file tree
Showing 12 changed files with 716 additions and 39 deletions.
8 changes: 4 additions & 4 deletions examples/decoder_only_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
from torch import nn


# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
@dataclass
class DecoderOnlyConfig:
hidden_size: int = 1024
num_hidden_layers: int = 2
num_attention_heads: int = 8
num_key_value_heads: int = 4
intermediate_size = 32 * 1024
vocab_size = 3200
use_flash_attention = False
intermediate_size: int = 32 * 1024
vocab_size: int = 3200
use_flash_attention: bool = False


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand Down
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_scan.py"
run_test "$CDIR/test_apply_layers.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
Expand Down
149 changes: 149 additions & 0 deletions test/test_apply_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(
sys.argv[0]))) + "/examples"
sys.path.append(example_folder)
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore

import sys
import unittest
from typing import Iterable

import torch

import torch_xla
from torch_xla.experimental.apply_layers import apply_layers

from test_utils import XlaTestCase # type:ignore


class ApplyLayersTest(XlaTestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()

def test_empty_layers(self):
layers = []
input_data = torch.randn(64).to(self.device)
torch_xla.sync()
output = apply_layers(layers, input_data.clone())
super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.01)

def test_linear_layers(self):
# We want to apply these layers sequentially
import torch.nn as nn
layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
input_data = torch.randn(64).to(self.device)

from copy import deepcopy
scan_layers = deepcopy(layers)
loop_layers = deepcopy(layers)
torch_xla.sync()

output = apply_layers(scan_layers, input_data.clone())
output.sum().backward()

# Test that the result is the same as for loop.
loop_output = input_data.clone()
from copy import deepcopy
for layer in loop_layers:
loop_output = layer(loop_output)
torch_xla.sync()

super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.01)

loop_output.sum().backward()
torch_xla.sync()

# Test that the gradients are the same too.
for layer_scan, layer_loop in zip(scan_layers, loop_layers):
super().compareResults(
layer_scan.weight.grad,
layer_loop.weight.grad,
abs_err=0.0001,
rel_err=0.01)
super().compareResults(
layer_scan.bias.grad,
layer_loop.bias.grad,
abs_err=0.0001,
rel_err=0.01)

def test_decoder_model(self):
# Define a decoder model that composes the decoder model in the example,
# but adds the ability to run the layers with the `scan` operator.
class DecoderOnlyModelWithScan(torch.nn.Module):

def __init__(self, **kwargs):
super(DecoderOnlyModelWithScan, self).__init__()
self.decoder = DecoderOnlyModel(**kwargs)

@property
def layers(self) -> Iterable[torch.nn.Module]:
return self.decoder.layers

def forward(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.decoder.forward(input_ids)

def forward_scan(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
inputs_embeds = self.decoder.embed_tokens(input_ids)
# embed positions
assert isinstance(inputs_embeds, torch.Tensor)
# decoder layers
hidden_states = apply_layers(self.decoder.layers, inputs_embeds)
hidden_states = self.decoder.norm(hidden_states)
# [B, S, H] -> [B, S, V]
return self.decoder.output(hidden_states)

# Make it smaller for fast model run and comparisons.
config = DecoderOnlyConfig(
hidden_size=128, intermediate_size=8 * 128, vocab_size=256)
model = DecoderOnlyModelWithScan(config=config).to(self.device)
batch_size = 2
sequence_length = 8

# Generate random input_ids within the range of the vocabulary size
input_ids = torch.randint(0, config.vocab_size,
(batch_size, sequence_length)).to(self.device)

from copy import deepcopy
loop_model = deepcopy(model)
scan_model = deepcopy(model)
torch_xla.sync()

# Run the loop-based model.
loop_output = loop_model(input_ids.clone())
loop_output.sum().backward()
torch_xla.sync()

# Run again, this time using `scan`
scan_output = scan_model.forward_scan(input_ids.clone())
scan_output.sum().backward()
torch_xla.sync()

# Compare results
super().compareResults(scan_output, loop_output, abs_err=0.05, rel_err=0.01)

# Check gradients
for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers):
for (name,
param_scan), (name2,
param_loop) in zip(layer_scan.named_parameters(),
layer_loop.named_parameters()):
assert name == name2
if param_scan.grad is not None or param_loop.grad is not None:
super().compareResults(
param_scan.grad, param_loop.grad, abs_err=0.1, rel_err=0.05)
print(f"Pass: {name} {param_scan.shape}")


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
24 changes: 24 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import itertools
import math
from numbers import Number
from functools import reduce
import numpy
import random
import re
Expand Down Expand Up @@ -2597,6 +2598,29 @@ def test_api(self):
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), 2)

def test_get_parameters_scalar(self):
"""Scalar tensors parameters may be shared in the HLO graph if their
numerical values are equal. `parameter_id_tensor_mapping` needs to handle
that appropriately.
"""

device = xm.xla_device()
tensors = []
for i in range(10):
# Add three copies of the same value.
tensors.append(torch.tensor(i, device=device))
tensors.append(torch.tensor(i, device=device))
tensors.append(torch.tensor(i, device=device))
result = reduce(lambda a, b: a + b, tensors)
ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([result])
mapping = ctx.parameter_id_tensor_mapping()

import json
hlo_json = json.loads(ctx.hlo_json())
num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
self.assertEqual(len(mapping), num_parameters)


class TestGeneric(test_utils.XlaTestCase):

Expand Down
70 changes: 51 additions & 19 deletions test/test_scan.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import sys
import unittest
import torch_xla
from functools import reduce

import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree

import torch_xla
from torch_xla.experimental.scan import scan
from torch.utils._pytree import tree_map, tree_flatten, tree_iter

from test_utils import XlaTestCase
from test_utils import XlaTestCase # type:ignore


def _loopy_scan(fn, init, xs):
Expand All @@ -24,6 +27,8 @@ def _loopy_scan(fn, init, xs):
class ScanTest(XlaTestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()

def compare_pytree(self, expected_pytree, actual_pytree):
Expand All @@ -32,31 +37,54 @@ def compare_pytree(self, expected_pytree, actual_pytree):
assert expected_spec == actual_spec
super().compareResults(flat_expected_pytree, flat_actual_pytree)

def run_test(self, step_fn, init, xs):
def run_test(self, fn, init: PyTree, xs: PyTree):
"""Compares the result of scanning with `fn` with our optimized HLO implementation
against a for loop implementation. Checks both output values and gradients.
"""
# Actual output
final_carry, ys = scan(step_fn, init, xs)
init_scan = tree_map(lambda v: v.detach().requires_grad_(), init)
xs_scan = tree_map(lambda v: v.detach().requires_grad_(), xs)
final_carry, ys = scan(fn, init_scan, xs_scan)
# Add up all leaves in `ys` and `backward()` once.
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(ys)),
torch.tensor(0.0)).backward()
torch_xla.sync()

# Expected output
expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs)
init_loop = tree_map(lambda v: v.detach().requires_grad_(), init)
xs_loop = tree_map(lambda v: v.detach().requires_grad_(), xs)
expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop)
# Add up all leaves in `ys` and `backward()` once.
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(expected_ys)),
torch.tensor(0.0)).backward()
torch_xla.sync()

# Compare
# Compare values
self.compare_pytree(expected_final_carry, final_carry)
self.compare_pytree(expected_ys, ys)

# Compare gradients
self.compare_pytree(
tree_map(lambda v: v.grad, init_scan),
tree_map(lambda v: v.grad, init_loop))
self.compare_pytree(
tree_map(lambda v: v.grad, xs_scan), tree_map(lambda v: v.grad,
xs_loop))

return final_carry, ys

def test_scan_forward_simple(self):
def test_scan_simple(self):
"""This test uses `scan` to implement `torch.cumsum`."""

def step_fn(carry, x):
new_carry = carry + x
y = new_carry
return new_carry, y

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
final_carry, ys = self.run_test(step_fn, init, xs)

# Also ensure that our loop-based scan is correct, with manual checks
Expand All @@ -80,26 +108,30 @@ def test_scan_incompatible_length(self):
with self.assertRaises(ValueError):
scan(lambda a, b: (a, b), init, (xs_1, xs_2))

def test_scan_forward_tuples(self):
def test_scan_tuples(self):
"""Test scanning over the leading axis of a tuple of tensors simultaneously,
which is a simple PyTree."""

def step_fn(carry, x):
def fn(carry, x):
carry1, carry2 = carry
x1, x2 = x
new_carry1 = carry1 + x1.sum()
new_carry2 = carry2 + x2.sum()
y1 = x1 * 2
y2 = x2 * 2
y1 = x1 * 2 + torch.sum(new_carry1)
y2 = x2 * 2 + torch.sum(new_carry2)
return (new_carry1, new_carry2), (y1, y2)

init = (torch.tensor([0.0], device=self.device),
torch.tensor([1.0, 2.0], device=self.device))
init = (torch.tensor([0.0], requires_grad=True, device=self.device),
torch.tensor([1.0, 2.0], requires_grad=True, device=self.device))

xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device),
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device))
xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]],
requires_grad=True,
device=self.device),
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]],
requires_grad=True,
device=self.device))

self.run_test(step_fn, init, xs)
self.run_test(fn, init, xs)


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/test_scan.py
python3 test/test_apply_layers.py
python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py
Expand Down
10 changes: 6 additions & 4 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,6 @@ class PyLoweringContext {
// etc.)
std::unordered_map<int64_t, at::Tensor> GetParameterIdTensorMapping() {
// Find parameters in the lowering
const std::vector<size_t>& param_ids = lowering_ctx.GetParameterSequence();
const std::vector<torch::lazy::BackendDataPtr>& device_data =
lowering_ctx.GetParametersData();

Expand All @@ -1081,7 +1080,9 @@ class PyLoweringContext {
at::ScalarType dtype =
MaybeUpcastToHostTorchType(literal.shape().element_type());
at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
results[param_ids[i]] = input;
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
XLA_CHECK(param_id.has_value());
results[param_id.value()] = input;
}
return results;
}
Expand All @@ -1104,12 +1105,13 @@ class PyLoweringContext {
torch::lazy::BackendData::Handle handle = data->GetHandle();

// Linearly search parameters and compare opaque handles
const std::vector<size_t>& param_ids = lowering_ctx.GetParameterSequence();
const std::vector<torch::lazy::BackendDataPtr>& device_data =
lowering_ctx.GetParametersData();
for (int i = 0; i < device_data.size(); ++i) {
if (device_data[i]->GetHandle() == handle) {
return param_ids[i];
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
XLA_CHECK(param_id.has_value());
return param_id.value();
}
}
return -1;
Expand Down
Loading

0 comments on commit 67cff9b

Please sign in to comment.