scan and apply_layers
Add the lowering of scan to the HLO While op.

Introduce apply_layers, which can sequentially apply a list of layers using scan underneath.

Beef up unit tests, including linear layers and decoders.

Add a regression test for parameter_id_tensor_mapping.

Add test_apply_layers.py to the test shell scripts.

Correctly import the decoder model from examples.
tengyifei committed Sep 25, 2024
1 parent 72fda76 commit ea640ab
Showing 12 changed files with 742 additions and 39 deletions.
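
For reference, here is a minimal usage sketch of the new apply_layers helper, assembled from the tests added in this commit; the layer sizes, input shape, and the trailing torch_xla.sync() call are illustrative rather than prescribed by the commit, and the cumulative-sum scan call mirrors the new test_scan.py test.

import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.apply_layers import apply_layers
from torch_xla.experimental.scan import scan

device = torch_xla.device()

# A stack of layers with identical signatures, to be applied one after another.
layers = [nn.Linear(64, 64).to(device) for _ in range(10)]
x = torch.randn(64, device=device)

# apply_layers runs the layers sequentially; because it is built on scan,
# the traced HLO contains a single While op instead of one unrolled copy per layer.
out = apply_layers(layers, x)
out.sum().backward()

# scan itself takes a step function mapping (carry, x) -> (new_carry, y);
# here it computes a cumulative sum over the leading dimension of xs.
final_carry, ys = scan(lambda carry, x: (carry + x, carry + x),
                       torch.zeros(2, device=device),
                       torch.ones(3, 2, device=device))
torch_xla.sync()
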
8 changes: 4 additions & 4 deletions examples/decoder_only_model.py
@@ -7,16 +7,16 @@
from torch import nn


# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
@dataclass
class DecoderOnlyConfig:
hidden_size: int = 1024
num_hidden_layers: int = 2
num_attention_heads: int = 8
num_key_value_heads: int = 4
intermediate_size = 32 * 1024
vocab_size = 3200
use_flash_attention = False
intermediate_size: int = 32 * 1024
vocab_size: int = 3200
use_flash_attention: bool = False


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
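
Why the added annotations matter (an inference from this diff, not stated in the commit message): in a Python dataclass, only annotated names become fields of the generated __init__, so the previously un-annotated intermediate_size, vocab_size, and use_flash_attention could not be overridden through the constructor the way the new test does. A small sketch with hypothetical class names:

from dataclasses import dataclass

@dataclass
class _OldStyleConfig:
  hidden_size: int = 1024
  intermediate_size = 32 * 1024  # un-annotated: a plain class attribute, not a field

@dataclass
class _NewStyleConfig:
  hidden_size: int = 1024
  intermediate_size: int = 32 * 1024  # annotated: part of the generated __init__

_NewStyleConfig(hidden_size=128, intermediate_size=8 * 128)  # works
try:
  _OldStyleConfig(hidden_size=128, intermediate_size=8 * 128)
except TypeError as err:
  # __init__() got an unexpected keyword argument 'intermediate_size'
  print(err)
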
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_scan.py"
run_test "$CDIR/test_apply_layers.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
175 changes: 175 additions & 0 deletions test/test_apply_layers.py
@@ -0,0 +1,175 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(
sys.argv[0]))) + "/examples"
sys.path.append(example_folder)
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore

import unittest
from copy import deepcopy
from typing import Iterable

import torch
import torch.nn as nn

import torch_xla
from torch_xla.experimental.apply_layers import apply_layers

from test_utils import XlaTestCase # type:ignore


class ApplyLayersTest(XlaTestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()

def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor):
assert a is not b, f"Expected {a} and {b} to be different tensors"
assert a.data is not b.data, f"Expected {a} and {b} to have different storage"

def assert_while_found_in_hlo(self, tensor: torch.Tensor):
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
assert "while(" in hlo_text
assert "condition=" in hlo_text
assert "body=" in hlo_text

def test_empty_layers(self):
layers = []
input_data = torch.randn(64).to(self.device)
output = apply_layers(layers, input_data.clone())
super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.01)

def test_linear_layers(self):
# We want to apply these layers sequentially
layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
input_data = torch.randn(64).to(self.device)

scan_layers = deepcopy(layers)
loop_layers = deepcopy(layers)
torch_xla.sync()

output = apply_layers(scan_layers, input_data.clone())
output.sum().backward()

# Test that the result is the same as for loop.
loop_output = input_data.clone()
for layer in loop_layers:
loop_output = layer(loop_output)
torch_xla.sync()

super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.01)
self.assert_different_tensor(loop_output, output)

loop_output.sum().backward()
torch_xla.sync()

# Test that the gradients are the same too.
for layer_scan, layer_loop in zip(scan_layers, loop_layers):
assert layer_scan.weight.grad is not None
assert layer_loop.weight.grad is not None
assert layer_scan.bias.grad is not None
assert layer_loop.bias.grad is not None
super().compareResults(
layer_scan.weight.grad,
layer_loop.weight.grad,
abs_err=0.0001,
rel_err=0.01)
super().compareResults(
layer_scan.bias.grad,
layer_loop.bias.grad,
abs_err=0.0001,
rel_err=0.01)
self.assert_different_tensor(layer_scan.weight.grad,
layer_loop.weight.grad)
self.assert_different_tensor(layer_scan.bias.grad, layer_loop.bias.grad)

def test_decoder_model(self):
# Define a decoder model that composes the decoder model in the example,
# but adds the ability to run the layers with the `scan` operator.
class DecoderOnlyModelWithScan(torch.nn.Module):

def __init__(self, **kwargs):
super(DecoderOnlyModelWithScan, self).__init__()
self.decoder = DecoderOnlyModel(**kwargs)

@property
def layers(self) -> Iterable[torch.nn.Module]:
return self.decoder.layers

def forward(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.decoder.forward(input_ids)

def forward_scan(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
inputs_embeds = self.decoder.embed_tokens(input_ids)
# embed positions
assert isinstance(inputs_embeds, torch.Tensor)
# decoder layers
hidden_states = apply_layers(self.decoder.layers, inputs_embeds)
hidden_states = self.decoder.norm(hidden_states)
# [B, S, H] -> [B, S, V]
return self.decoder.output(hidden_states)

# Make it smaller for fast model run and comparisons.
config = DecoderOnlyConfig(
hidden_size=128, intermediate_size=8 * 128, vocab_size=256)
model = DecoderOnlyModelWithScan(config=config).to(self.device)
batch_size = 2
sequence_length = 8

# Generate random input_ids within the range of the vocabulary size
input_ids = torch.randint(0, config.vocab_size,
(batch_size, sequence_length)).to(self.device)

loop_model = deepcopy(model)
scan_model = deepcopy(model)
torch_xla.sync()

# Run the loop-based model.
loop_output = loop_model(input_ids.clone())
loop_output.sum().backward()
torch_xla.sync()

# Run again, this time using `scan`
scan_output = scan_model.forward_scan(input_ids.clone())
scan_output.sum().backward()

# Before materializing the tensors, check that tensor HLO has `While` in it.
self.assert_while_found_in_hlo(scan_output)
for layer_scan in scan_model.layers:
for (name, param_scan) in layer_scan.named_parameters():
if param_scan.grad is not None:
self.assert_while_found_in_hlo(param_scan.grad)

torch_xla.sync()

# Compare results
super().compareResults(scan_output, loop_output, abs_err=0.05, rel_err=0.01)

# Check gradients
for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers):
for (name,
param_scan), (name2,
param_loop) in zip(layer_scan.named_parameters(),
layer_loop.named_parameters()):
assert name == name2
# Either the parameter should have gradient in both, or it should not
# have gradient in both.
assert (param_scan.grad is not None) == (param_loop.grad is not None)
# Check gradients
if param_scan.grad is not None or param_loop.grad is not None:
super().compareResults(
param_scan.grad, param_loop.grad, abs_err=0.1, rel_err=0.05)
print(f"Pass: {name} {param_scan.shape}")


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
24 changes: 24 additions & 0 deletions test/test_operations.py
@@ -21,6 +21,7 @@
import itertools
import math
from numbers import Number
from functools import reduce
import numpy
import random
import re
@@ -2597,6 +2598,29 @@ def test_api(self):
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), 2)

def test_get_parameters_scalar(self):
"""Scalar tensors parameters may be shared in the HLO graph if their
numerical values are equal. `parameter_id_tensor_mapping` needs to handle
that appropriately.
"""

device = torch_xla.device()
tensors = []
for i in range(10):
# Add three copies of the same value.
tensors.append(torch.tensor(i, device=device))
tensors.append(torch.tensor(i, device=device))
tensors.append(torch.tensor(i, device=device))
result = reduce(lambda a, b: a + b, tensors)
ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([result])
mapping = ctx.parameter_id_tensor_mapping()

import json
hlo_json = json.loads(ctx.hlo_json())
num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
self.assertEqual(len(mapping), num_parameters)


class TestGeneric(test_utils.XlaTestCase):

70 changes: 51 additions & 19 deletions test/test_scan.py
@@ -1,11 +1,14 @@
import sys
import unittest
import torch_xla
from functools import reduce

import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree

import torch_xla
from torch_xla.experimental.scan import scan
from torch.utils._pytree import tree_map, tree_flatten, tree_iter

from test_utils import XlaTestCase
from test_utils import XlaTestCase # type:ignore


def _loopy_scan(fn, init, xs):
@@ -24,6 +27,8 @@ def _loopy_scan(fn, init, xs):
class ScanTest(XlaTestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()

def compare_pytree(self, expected_pytree, actual_pytree):
@@ -32,31 +37,54 @@ def compare_pytree(self, expected_pytree, actual_pytree):
assert expected_spec == actual_spec
super().compareResults(flat_expected_pytree, flat_actual_pytree)

def run_test(self, step_fn, init, xs):
def run_test(self, fn, init: PyTree, xs: PyTree):
"""Compares the result of scanning with `fn` with our optimized HLO implementation
against a for loop implementation. Checks both output values and gradients.
"""
# Actual output
final_carry, ys = scan(step_fn, init, xs)
init_scan = tree_map(lambda v: v.detach().requires_grad_(), init)
xs_scan = tree_map(lambda v: v.detach().requires_grad_(), xs)
final_carry, ys = scan(fn, init_scan, xs_scan)
# Add up all leaves in `ys` and `backward()` once.
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(ys)),
torch.tensor(0.0)).backward()
torch_xla.sync()

# Expected output
expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs)
init_loop = tree_map(lambda v: v.detach().requires_grad_(), init)
xs_loop = tree_map(lambda v: v.detach().requires_grad_(), xs)
expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop)
# Add up all leaves in `ys` and `backward()` once.
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(expected_ys)),
torch.tensor(0.0)).backward()
torch_xla.sync()

# Compare
# Compare values
self.compare_pytree(expected_final_carry, final_carry)
self.compare_pytree(expected_ys, ys)

# Compare gradients
self.compare_pytree(
tree_map(lambda v: v.grad, init_scan),
tree_map(lambda v: v.grad, init_loop))
self.compare_pytree(
tree_map(lambda v: v.grad, xs_scan), tree_map(lambda v: v.grad,
xs_loop))

return final_carry, ys

def test_scan_forward_simple(self):
def test_scan_simple(self):
"""This test uses `scan` to implement `torch.cumsum`."""

def step_fn(carry, x):
new_carry = carry + x
y = new_carry
return new_carry, y

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
final_carry, ys = self.run_test(step_fn, init, xs)

# Also ensure that our loop-based scan is correct, with manual checks
@@ -80,26 +108,30 @@ def test_scan_incompatible_length(self):
with self.assertRaises(ValueError):
scan(lambda a, b: (a, b), init, (xs_1, xs_2))

def test_scan_forward_tuples(self):
def test_scan_tuples(self):
"""Test scanning over the leading axis of a tuple of tensors simultaneously,
which is a simple PyTree."""

def step_fn(carry, x):
def fn(carry, x):
carry1, carry2 = carry
x1, x2 = x
new_carry1 = carry1 + x1.sum()
new_carry2 = carry2 + x2.sum()
y1 = x1 * 2
y2 = x2 * 2
y1 = x1 * 2 + torch.sum(new_carry1)
y2 = x2 * 2 + torch.sum(new_carry2)
return (new_carry1, new_carry2), (y1, y2)

init = (torch.tensor([0.0], device=self.device),
torch.tensor([1.0, 2.0], device=self.device))
init = (torch.tensor([0.0], requires_grad=True, device=self.device),
torch.tensor([1.0, 2.0], requires_grad=True, device=self.device))

xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device),
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device))
xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]],
requires_grad=True,
device=self.device),
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]],
requires_grad=True,
device=self.device))

self.run_test(step_fn, init, xs)
self.run_test(fn, init, xs)


if __name__ == '__main__':
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -24,6 +24,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/test_scan.py
python3 test/test_apply_layers.py
python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py