Fix a bug in convolution for torch_xla2. Enable tests (pytorch#6496)
qihqi authored and amithrm committed Mar 1, 2024
1 parent 1e93b22 commit 96c8e30
Showing 9 changed files with 199 additions and 25 deletions.
44 changes: 44 additions & 0 deletions .github/workflows/torch_xla2.yml
@@ -0,0 +1,44 @@
on:
pull_request:
branches:
- master
- r[0-9]+.[0-9]+
paths:
- 'experimental/torch_xla2/**'
push:
branches:
- master
- r[0-9]+.[0-9]+
paths:
- 'experimental/torch_xla2/**'
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true

jobs:
torchxla2-cpu:
runs-on: ubuntu-20.04
steps:
- name: Checkout repo
uses: actions/checkout@v4
with:
sparse-checkout: |
experimental/torch_xla2
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install
shell: bash
working-directory: experimental/torch_xla2
run: |
pip install pytest absl-py jax[cpu] flatbuffers tensorflow
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -e .
- name: Run tests
working-directory: experimental/torch_xla2
shell: bash
run: |
pytest test/
27 changes: 22 additions & 5 deletions experimental/torch_xla2/requirements.txt
@@ -1,5 +1,22 @@
jax==0.4.24.dev20240202
jaxlib==0.4.24.dev20240202
numpy==1.26.3
torch==2.2.0
typing_extensions==4.9.0
certifi==2022.12.7
charset-normalizer==2.1.1
filelock==3.9.0
fsspec==2023.4.0
idna==3.4
jax==0.4.24
jaxlib==0.4.24
Jinja2==3.1.2
MarkupSafe==2.1.3
ml-dtypes==0.3.2
mpmath==1.2.1
networkx==3.0rc1
numpy==1.24.1
opt-einsum==3.3.0
Pillow==9.3.0
requests==2.28.1
scipy==1.12.0
sympy==1.11.1
--pre
torch
typing_extensions==4.8.0
urllib3==1.26.13
4 changes: 4 additions & 0 deletions experimental/torch_xla2/test/llama/llama_model.py
@@ -244,6 +244,10 @@ def forward(

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
print('q=', q.shape)
print('k=', k.shape)
print('v=', v.shape)
print('mask=', mask.shape)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
79 changes: 79 additions & 0 deletions experimental/torch_xla2/test/test_conv.py
@@ -0,0 +1,79 @@
import torch
from torch import nn
import torch_xla2
from . import test_base

class CustomConv1(torch.nn.Module):

def __init__(
self,
channels_conv1=3,
width_conv1=3,
channels_conv2=5,
width_conv2=5,
hidden_layer_size=50,
):
super(CustomConv1, self).__init__()
self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1)
self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2)
self.fc1 = nn.Linear(hidden_layer_size, 2)

def forward(self, x):
x = nn.functional.max_pool1d(nn.functional.relu(self.conv1(x)), 2, stride=2)
x = nn.functional.max_pool1d(nn.functional.relu(self.conv2(x)), 2, stride=2)
x = torch.flatten(x, 1)
x = nn.functional.softmax(self.fc1(x), dim=1)
return x


class CustomConv2(nn.Module):

def __init__(self):
super().__init__()
inp = 4
out = 16

self.conv = nn.Conv2d(inp, out, kernel_size=3, padding=1)

# This is supposed to be a squeeze and excitation block.
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

self.scale = nn.Sequential(nn.Linear(out, out), nn.Sigmoid())

def forward(self, x):
x = self.conv(x)

b = x.shape[0]
ap = self.avg_pool(x).view(b, -1)
ap = self.scale(ap)
ap = ap.view(b, -1, 1, 1)

return x * ap


class ConvTest(test_base.TestCase):

def test_conv1(self):
m = CustomConv1()
arg = torch.randn((20, 1, 50))
res = m(arg)

jax_weights, jax_func = torch_xla2.extract_jax(m)
arg = torch_xla2.tensor.t2j(arg)
res2 = jax_func(jax_weights, (arg, ))
res2_torch = torch_xla2.tensor.j2t(res2)
self.assertTrue(torch.allclose(res, res2_torch))

def test_conv2(self):
m = CustomConv2()
arg = torch.randn((20, 4, 50, 100))
res = m(arg)
jax_weights, jax_func = torch_xla2.extract_jax(m)
arg = torch_xla2.tensor.t2j(arg)
res2 = jax_func(jax_weights, (arg, ))
res2_torch = torch_xla2.tensor.j2t(res2)
self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4))


if __name__ == '__main__':
test_base.main()
12 changes: 12 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
@@ -1762,6 +1762,18 @@ def test_aten_index_put_2(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs)

def test_aten_index_put_3(self):
args = (
torch.randint(0, 10, (10, 10)).to(torch.int32),
[
torch.randint(0, 10, (1,)).to(torch.int64),
],
torch.randint(0, 10, (10,)).to(torch.int32),
True,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs)

def test_aten_index_select_0(self):
args = (
torch.randn((2, 10)).to(torch.float32),
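The new test_aten_index_put_3 above passes True as the fourth (accumulate) argument, which exercises the index_put lowering fixed in ops.py further below. As a minimal torch-level illustration of the two modes (hypothetical shapes, not the test's random inputs): accumulate=False overwrites the indexed rows, accumulate=True adds onto them; on the JAX side this corresponds to x.at[idx].set(values) versus x.at[idx].add(values).

import torch

base = torch.ones(3, 3)
indices = [torch.tensor([0])]           # select row 0
values = torch.tensor([5.0, 5.0, 5.0])

overwritten = torch.ops.aten.index_put(base, indices, values, False)
accumulated = torch.ops.aten.index_put(base, indices, values, True)
# overwritten[0] == [5., 5., 5.]; accumulated[0] == [6., 6., 6.]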
10 changes: 7 additions & 3 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,19 +1,23 @@
import jax
import torch
import torch._functorch
from torch._functorch import make_functional
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
from torch_xla2 import export, ops, ops_registry, tensor, tf_integration


def extract_jax(mod: torch.nn.Module):
"""Returns a pytree of jax.ndarray and a jax callable."""
func, weights, buffer = torch._functorch.make_functional_with_buffers(mod)
func, weights, buffer = make_functional.make_functional_with_buffers(mod)
states = (weights, buffer)
states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)

@jax.jit
def jax_func(states, inputs):
(states, inputs) = tensor.wrap((states, inputs))
weights, buffer = states
res = func(weights, buffer, *inputs)
with tensor.XLADispatchMode():
res = func(weights, buffer, *inputs)
return tensor.unwrap(res)

return states, jax_func
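With this change the functionalized module runs under tensor.XLADispatchMode, so torch ops inside the module dispatch to the jax lowerings in ops.py. A usage sketch mirroring the new test_conv.py above (the Linear module and shapes here are illustrative only):

import torch
import torch_xla2

m = torch.nn.Linear(4, 2)                        # any torch.nn.Module; the tests use the conv models above
jax_weights, jax_func = torch_xla2.extract_jax(m)

arg = torch_xla2.tensor.t2j(torch.randn(8, 4))   # torch.Tensor -> jax array
res_jax = jax_func(jax_weights, (arg,))          # jitted jax callable over (weights, inputs)
res_torch = torch_xla2.tensor.j2t(res_jax)       # back to a torch.Tensor for comparison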
15 changes: 10 additions & 5 deletions experimental/torch_xla2/torch_xla2/export.py
@@ -156,7 +156,7 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
):
return super().call_function(target, args, kwargs)

print('Running ', target.name(), '--------')
# print('Running ', target.name(), '--------')

op = ops_registry.lowerings.lookup(target)
if op is None:
@@ -166,13 +166,18 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:

def run_node(self, n) -> Any:
res = super().run_node(n)
if n.op == 'call_function':
if hasattr(res, 'shape'):
print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
#if n.op == 'call_function':
# if hasattr(res, 'shape'):
# print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
return res

from torch._decomp import get_decompositions
import torch._refs
_extra_decomp = get_decompositions(
[torch.ops.aten.unfold]
)


_extra_decomp = {}


def exported_program_to_jax(exported_program):
25 changes: 16 additions & 9 deletions experimental/torch_xla2/torch_xla2/ops.py
@@ -2,7 +2,6 @@
"""Torch ops implemented using jax."""
import sys

import flax
import jax
from jax import numpy as jnp
import numpy as np
@@ -298,10 +297,13 @@ def _aten_empty(sizes, **kwargs):

@op(torch.ops.aten.index_put_)
@op(torch.ops.aten.index_put)
def _aten_index_put(self, indexes, values):
def _aten_index_put(self, indexes, values, accumulate=False):
indexes = [slice(None, None, None) if i is None else i for i in indexes]
indexes = tuple(indexes)
return self.at[indexes].set(values)
if accumulate:
return self.at[indexes].add(values)
else:
return self.at[indexes].set(values)


@op(torch.ops.aten.index)
@@ -508,8 +510,11 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
)

if bias is not None:
# TODO(qihqi): this is wrong
bias = bias.reshape(bias.shape + (1,))
# TODO(qihqi): bias always on channel?
if len(bias.shape) == 1:
shape = [1] * len(res.shape)
shape[1] = bias.shape[0]
bias = bias.reshape(tuple(shape))
res = res + bias
return res

@@ -1200,19 +1205,21 @@ def _aten_arange(
start,
end=None,
step=1,
*,
dtype=None,
layout=None,
requires_grad=False,
device=None,
pin_memory=False,
pin_memory=False
):
if end is None:
end = start
start = 0
return jnp.arange(
start,
end,
step,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
)


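The convolution fix above replaces the old bias reshape, which appended a single trailing axis and was flagged with a TODO, with one that places a 1-D bias along the channel dimension of the conv output so broadcasting works for any number of spatial dims. A standalone sketch of that reshape for an NCHW output (illustrative shapes, not code from the commit):

from jax import numpy as jnp

res = jnp.zeros((20, 16, 50, 100))       # conv output laid out as N, C_out, H, W
bias = jnp.arange(16.0)                  # 1-D bias, one value per output channel

shape = [1] * len(res.shape)             # [1, 1, 1, 1]
shape[1] = bias.shape[0]                 # [1, 16, 1, 1]: channel axis is dim 1
res = res + bias.reshape(tuple(shape))   # broadcasts the bias over N, H and W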
8 changes: 5 additions & 3 deletions experimental/torch_xla2/torch_xla2/tensor.py
@@ -131,10 +131,12 @@ def __new__(cls, elem):
shape[i] = 1
if dtype is None:
dtype = torch.float32
return torch.Tensor._make_subclass(
return torch.Tensor._make_wrapper_subclass(
cls,
torch.empty(shape, dtype=dtype, device="meta"),
require_grad=False,
shape,
dtype=dtype,
device='meta',
requires_grad=False,
)

def __init__(self, elem: jax.Array):
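The tensor.py change switches construction of the tensor subclass from _make_subclass (which also passed a misspelled require_grad keyword) to _make_wrapper_subclass, which creates a tensor shell carrying only shape, dtype and device metadata while the real data stays in the jax.Array held by the instance. A minimal sketch of that pattern (hypothetical class, not the real XLATensor2 in this file):

import torch

class JaxBackedTensor(torch.Tensor):
  """Wrapper subclass sketch: the payload lives on the instance, not in storage."""

  def __new__(cls, payload, shape, dtype=torch.float32):
    # No real storage is allocated; shape/dtype/device are metadata only.
    return torch.Tensor._make_wrapper_subclass(
        cls, shape, dtype=dtype, device='meta', requires_grad=False)

  def __init__(self, payload, shape, dtype=torch.float32):
    self._elem = payload  # e.g. a jax.Array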
