scan and apply_layers #7901

Open · wants to merge 1 commit into base: master
8 changes: 4 additions & 4 deletions examples/decoder_only_model.py
@@ -7,16 +7,16 @@
from torch import nn


# the default config is intentionally kept low to make it runable on a sigle tpu v2-8 core.
# the default config is intentionally kept low to make it runnable on a single tpu v2-8 core.
@dataclass
class DecoderOnlyConfig:
hidden_size: int = 1024
num_hidden_layers: int = 2
num_attention_heads: int = 8
num_key_value_heads: int = 4
intermediate_size = 32 * 1024
vocab_size = 3200
use_flash_attention = False
intermediate_size: int = 32 * 1024
vocab_size: int = 3200
use_flash_attention: bool = False


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_scan.py"
run_test "$CDIR/test_apply_layers.py"
run_test "$CDIR/test_autocast.py"
run_test "$CDIR/eager/test_eager.py"
run_test "$CDIR/eager/test_eager_with_xla_compile.py"
189 changes: 189 additions & 0 deletions test/test_apply_layers.py
@@ -0,0 +1,189 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(
sys.argv[0]))) + "/examples"
sys.path.append(example_folder)
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel # type:ignore

import unittest
from copy import deepcopy
from typing import Iterable

import torch
import torch.nn as nn

import torch_xla
from torch_xla.experimental.apply_layers import apply_layers

from test_utils import XlaTestCase # type:ignore


class ApplyLayersTest(XlaTestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()

def assert_different_tensor(self, a: torch.Tensor, b: torch.Tensor):
assert a is not b, f"Expected {a} and {b} to be different tensors"
assert a.data is not b.data, f"Expected {a} and {b} to have different storage"

def assert_while_found_in_hlo(self, tensor: torch.Tensor):
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
assert "while(" in hlo_text
assert "condition=" in hlo_text
assert "body=" in hlo_text

def test_empty_layers(self):
layers = []
input_data = torch.randn(64).to(self.device)
output = apply_layers(layers, input_data.clone())
super().compareResults(output, input_data, abs_err=0.0001, rel_err=0.01)

def test_linear_layers(self):
# We want to apply these layers sequentially
layers = [nn.Linear(64, 64).to(self.device) for _ in range(10)]
input_data = torch.randn(64).to(self.device)

scan_layers = deepcopy(layers)
loop_layers = deepcopy(layers)
torch_xla.sync()

output = apply_layers(scan_layers, input_data.clone())
output.sum().backward()

# Test that the result is the same as for loop.
loop_output = input_data.clone()
for layer in loop_layers:
loop_output = layer(loop_output)
torch_xla.sync()

super().compareResults(loop_output, output, abs_err=0.0001, rel_err=0.01)
self.assert_different_tensor(loop_output, output)

loop_output.sum().backward()
torch_xla.sync()

# Test that the gradients are the same too.
for layer_scan, layer_loop in zip(scan_layers, loop_layers):
assert layer_scan.weight.grad is not None
assert layer_loop.weight.grad is not None
assert layer_scan.bias.grad is not None
assert layer_loop.bias.grad is not None
super().compareResults(
layer_scan.weight.grad,
layer_loop.weight.grad,
abs_err=0.0001,
rel_err=0.01)
super().compareResults(
layer_scan.bias.grad,
layer_loop.bias.grad,
abs_err=0.0001,
rel_err=0.01)
self.assert_different_tensor(layer_scan.weight.grad,
layer_loop.weight.grad)
self.assert_different_tensor(layer_scan.bias.grad, layer_loop.bias.grad)

def test_decoder_model(self):
# Define a decoder model that composes the decoder model in the example,
# but adds the ability to run the layers with the `scan` operator.
class DecoderOnlyModelWithScan(torch.nn.Module):

def __init__(self, **kwargs):
super(DecoderOnlyModelWithScan, self).__init__()
self.decoder = DecoderOnlyModel(**kwargs)

@property
def layers(self) -> Iterable[torch.nn.Module]:
return self.decoder.layers

def forward(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.decoder.forward(input_ids)

def forward_scan(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
inputs_embeds = self.decoder.embed_tokens(input_ids)
# embed positions
assert isinstance(inputs_embeds, torch.Tensor)
# decoder layers
hidden_states = apply_layers(self.decoder.layers, inputs_embeds)
hidden_states = self.decoder.norm(hidden_states)
# [B, S, H] -> [B, S, V]
return self.decoder.output(hidden_states)

# Make it smaller for fast model run and comparisons.
config = DecoderOnlyConfig(
hidden_size=128, intermediate_size=8 * 128, vocab_size=256)
model = DecoderOnlyModelWithScan(config=config).to(self.device)
batch_size = 2
sequence_length = 8

# Generate random input_ids within the range of the vocabulary size
input_ids = torch.randint(0, config.vocab_size,
(batch_size, sequence_length)).to(self.device)

loop_model = deepcopy(model)
scan_model = deepcopy(model)
torch_xla.sync()

# Run the loop-based model.
loop_output = loop_model(input_ids.clone())
loop_output.sum().backward()
torch_xla.sync()

# Run again, this time using `scan`
scan_output = scan_model.forward_scan(input_ids.clone())
scan_output.sum().backward()

# Before materializing the tensors, check that tensor HLO has `While` in it.
self.assert_while_found_in_hlo(scan_output)
for layer_scan in scan_model.layers:
for (name, param_scan) in layer_scan.named_parameters():
if param_scan.grad is not None:
self.assert_while_found_in_hlo(param_scan.grad)

torch_xla.sync()

# Compare results
super().compareResults(scan_output, loop_output, abs_err=0.05, rel_err=0.01)

# Check gradients
for layer_scan, layer_loop in zip(scan_model.layers, loop_model.layers):
for (name,
param_scan), (name2,
param_loop) in zip(layer_scan.named_parameters(),
layer_loop.named_parameters()):
assert name == name2
# Either the parameter should have gradient in both, or it should not
# have gradient in both.
assert (param_scan.grad is not None) == (param_loop.grad is not None)
# Check gradients
if param_scan.grad is not None or param_loop.grad is not None:
super().compareResults(
param_scan.grad, param_loop.grad, abs_err=0.1, rel_err=0.05)
print(f"Pass: {name} {param_scan.shape}")

def test_heterogeneous_layers(self):
layer1 = nn.Linear(128, 128).to(torch_xla.device())
layer2 = nn.Sequential(nn.Linear(128, 128).to(torch_xla.device()))
with self.assertRaisesRegex(ValueError, "mismatched set of parameters"):
apply_layers([layer1, layer2],
torch.zeros((128,), device=torch_xla.device()))

def test_mismatched_shapes(self):
layer1 = nn.Linear(128, 128).to(torch_xla.device())
layer2 = nn.Linear(128, 129).to(torch_xla.device())
with self.assertRaisesRegex(ValueError, "Shape mismatch"):
apply_layers([layer1, layer2],
torch.zeros((128,), device=torch_xla.device()))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
24 changes: 24 additions & 0 deletions test/test_operations.py
@@ -21,6 +21,7 @@
import itertools
import math
from numbers import Number
from functools import reduce
import numpy
import random
import re
@@ -2597,6 +2598,29 @@ def test_api(self):
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), 2)

def test_get_parameters_scalar(self):
"""Scalar tensors parameters may be shared in the HLO graph if their
numerical values are equal. `parameter_id_tensor_mapping` needs to handle
that appropriately.
"""

device = torch_xla.device()
tensors = []
for i in range(10):
# Add three copies of the same value.
tensors.append(torch.tensor(i, device=device))
tensors.append(torch.tensor(i, device=device))
tensors.append(torch.tensor(i, device=device))
result = reduce(lambda a, b: a + b, tensors)
ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([result])
mapping = ctx.parameter_id_tensor_mapping()

import json
hlo_json = json.loads(ctx.hlo_json())
num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
self.assertEqual(len(mapping), num_parameters)
Review comment (Collaborator): So you expect both values to be 10?

Reply (Collaborator, Author): Unfortunately not. It looks like some integer values (e.g. values <= 2) are shared when you put multiple copies into the HLO, but values above 2 are not. So we don't necessarily get 10. In any case, the precise number of parameters appears to be an implementation detail that we can't reliably test.
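
The behavior described in this exchange is easy to probe directly. Below is a minimal sketch (not part of this PR) that reuses the same LoweringContext API exercised by test_get_parameters_scalar above to print how many HLO parameters result from several equal-valued scalar tensors; the sharing cutoff mentioned in the reply is an implementation detail and may differ across builds.

import json
from functools import reduce

import torch
import torch_xla

device = torch_xla.device()
for value in (1, 5):
  # Three device scalars holding the same value.
  tensors = [torch.tensor(value, device=device) for _ in range(3)]
  result = reduce(lambda a, b: a + b, tensors)

  ctx = torch_xla._XLAC.lowering.LoweringContext()
  ctx.build([result])

  # The mapping should always agree with the program shape, whether or not
  # equal-valued scalars were deduplicated into a single parameter.
  mapping = ctx.parameter_id_tensor_mapping()
  hlo_json = json.loads(ctx.hlo_json())
  num_parameters = len(hlo_json["hostProgramShape"]["parameters"])
  assert len(mapping) == num_parameters
  print(f"value={value}: {num_parameters} HLO parameter(s)")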



class TestGeneric(test_utils.XlaTestCase):

73 changes: 54 additions & 19 deletions test/test_scan.py
@@ -1,11 +1,14 @@
import sys
import unittest
import torch_xla
from functools import reduce

import torch
from torch.utils._pytree import tree_map, tree_flatten, tree_iter, tree_leaves, PyTree

import torch_xla
from torch_xla.experimental.scan import scan
from torch.utils._pytree import tree_map, tree_flatten, tree_iter

from test_utils import XlaTestCase
from test_utils import XlaTestCase # type:ignore


def _loopy_scan(fn, init, xs):
@@ -24,6 +27,8 @@ def _loopy_scan(fn, init, xs):
class ScanTest(XlaTestCase):

def setUp(self):
super().setUp()

self.device = torch_xla.device()

def compare_pytree(self, expected_pytree, actual_pytree):
@@ -32,31 +37,54 @@ def compare_pytree(self, expected_pytree, actual_pytree):
assert expected_spec == actual_spec
super().compareResults(flat_expected_pytree, flat_actual_pytree)

def run_test(self, step_fn, init, xs):
def run_test(self, fn, init: PyTree, xs: PyTree):
"""Compares the result of scanning with `fn` with our optimized HLO implementation
against a for loop implementation. Checks both output values and gradients.
"""
# Actual output
final_carry, ys = scan(step_fn, init, xs)
init_scan = tree_map(lambda v: v.detach().requires_grad_(), init)
xs_scan = tree_map(lambda v: v.detach().requires_grad_(), xs)
final_carry, ys = scan(fn, init_scan, xs_scan)
# Add up all leaves in `ys` and `backward()` once.
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(ys)),
torch.tensor(0.0)).backward()
torch_xla.sync()

# Expected output
expected_final_carry, expected_ys = _loopy_scan(step_fn, init, xs)
init_loop = tree_map(lambda v: v.detach().requires_grad_(), init)
xs_loop = tree_map(lambda v: v.detach().requires_grad_(), xs)
expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop)
# Add up all leaves in `ys` and `backward()` once.
reduce(lambda a, b: a + b, map(lambda v: v.sum(), tree_leaves(expected_ys)),
torch.tensor(0.0)).backward()
torch_xla.sync()

# Compare
# Compare values
self.compare_pytree(expected_final_carry, final_carry)
self.compare_pytree(expected_ys, ys)

# Compare gradients
self.compare_pytree(
tree_map(lambda v: v.grad, init_scan),
tree_map(lambda v: v.grad, init_loop))
self.compare_pytree(
tree_map(lambda v: v.grad, xs_scan), tree_map(lambda v: v.grad,
xs_loop))

return final_carry, ys

def test_scan_forward_simple(self):
def test_scan_simple(self):
"""This test uses `scan` to implement `torch.cumsum`."""

def step_fn(carry, x):
new_carry = carry + x
y = new_carry
return new_carry, y

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
init = torch.tensor([0.0, 0.0], requires_grad=True, device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
final_carry, ys = self.run_test(step_fn, init, xs)

# Also ensure that our loop-based scan is correct, with manual checks
@@ -80,26 +108,33 @@ def test_scan_incompatible_length(self):
with self.assertRaises(ValueError):
scan(lambda a, b: (a, b), init, (xs_1, xs_2))

def test_scan_forward_tuples(self):
def test_scan_tuples(self):
"""Test scanning over the leading axis of a tuple of tensors simultaneously,
which is a simple PyTree."""

def step_fn(carry, x):
def fn(carry, x):
carry1, carry2 = carry
x1, x2 = x
new_carry1 = carry1 + x1.sum()
new_carry2 = carry2 + x2.sum()
y1 = x1 * 2
y2 = x2 * 2
y1 = x1 * 2 + torch.sum(new_carry1)
y2 = x2 * 2 + torch.sum(new_carry2)
return (new_carry1, new_carry2), (y1, y2)

init = (torch.tensor([0.0], device=self.device),
torch.tensor([1.0, 2.0], device=self.device))
init = (torch.tensor([0.0], requires_grad=True, device=self.device),
torch.tensor([1.0, 2.0], requires_grad=True, device=self.device))

xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]],
requires_grad=True,
device=self.device),
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]],
requires_grad=True,
device=self.device))

xs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]], device=self.device),
torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]], device=self.device))
self.run_test(fn, init, xs)

self.run_test(step_fn, init, xs)
# TODO(yifeit): Add a test involving in-place updates
# TODO(yifeit): Add a test involving RNG


if __name__ == '__main__':
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -24,6 +24,7 @@ python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/test_scan.py
python3 test/test_apply_layers.py
python3 test/test_pallas.py
python3 test/test_pallas_spmd.py
python3 test/test_input_output_aliases.py