Fix a bug in convolution for torch_xla2. Enable tests #6496

Merged 1 commit on Feb 8, 2024
44 changes: 44 additions & 0 deletions .github/workflows/torch_xla2.yml
@@ -0,0 +1,44 @@
on:
  pull_request:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'experimental/torch_xla2/**'
  push:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'experimental/torch_xla2/**'
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

jobs:
  torchxla2-cpu:
    runs-on: ubuntu-20.04
    steps:
      - name: Checkout repo
        uses: actions/checkout@v4
        with:
          sparse-checkout: |
            experimental/torch_xla2
      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install
        shell: bash
        working-directory: experimental/torch_xla2
        run: |
          pip install pytest absl-py jax[cpu] flatbuffers tensorflow
          pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
          pip install -e .
      - name: Run tests
        working-directory: experimental/torch_xla2
        shell: bash
        run: |
          pytest test/
27 changes: 22 additions & 5 deletions experimental/torch_xla2/requirements.txt
@@ -1,5 +1,22 @@
jax==0.4.24.dev20240202
jaxlib==0.4.24.dev20240202
numpy==1.26.3
torch==2.2.0
typing_extensions==4.9.0
certifi==2022.12.7
charset-normalizer==2.1.1
filelock==3.9.0
fsspec==2023.4.0
idna==3.4
jax==0.4.24
jaxlib==0.4.24
Jinja2==3.1.2
MarkupSafe==2.1.3
ml-dtypes==0.3.2
mpmath==1.2.1
networkx==3.0rc1
numpy==1.24.1
opt-einsum==3.3.0
Pillow==9.3.0
requests==2.28.1
scipy==1.12.0
sympy==1.11.1
--pre
torch
typing_extensions==4.8.0
urllib3==1.26.13
4 changes: 4 additions & 0 deletions experimental/torch_xla2/test/llama/llama_model.py
@@ -244,6 +244,10 @@ def forward(

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
print('q=', q.shape)
print('k=', k.shape)
print('v=', v.shape)
print('mask=', mask.shape)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
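The four added print calls are shape probes around scaled_dot_product_attention. For reference, a minimal sketch of the shapes that call expects (illustration only; the sizes are hypothetical and not taken from the model):

import torch
import torch.nn.functional as F

# Hypothetical sizes chosen only to illustrate the broadcast rules.
bsz, n_head, seqlen, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, n_head, seqlen, head_dim)
k = torch.randn(bsz, n_head, seqlen, head_dim)
v = torch.randn(bsz, n_head, seqlen, head_dim)
# A boolean mask broadcastable to (bsz, n_head, q_len, k_len); True means "attend".
mask = torch.ones(1, 1, seqlen, seqlen).tril().bool()
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
print(y.shape)  # torch.Size([2, 8, 16, 64])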
79 changes: 79 additions & 0 deletions experimental/torch_xla2/test/test_conv.py
@@ -0,0 +1,79 @@
import torch
from torch import nn
import torch_xla2
from . import test_base

class CustomConv1(torch.nn.Module):

  def __init__(
      self,
      channels_conv1=3,
      width_conv1=3,
      channels_conv2=5,
      width_conv2=5,
      hidden_layer_size=50,
  ):
    super(CustomConv1, self).__init__()
    self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1)
    self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2)
    self.fc1 = nn.Linear(hidden_layer_size, 2)

  def forward(self, x):
    x = nn.functional.max_pool1d(nn.functional.relu(self.conv1(x)), 2, stride=2)
    x = nn.functional.max_pool1d(nn.functional.relu(self.conv2(x)), 2, stride=2)
    x = torch.flatten(x, 1)
    x = nn.functional.softmax(self.fc1(x), dim=1)
    return x


class CustomConv2(nn.Module):

  def __init__(self):
    super().__init__()
    inp = 4
    out = 16

    self.conv = nn.Conv2d(inp, out, kernel_size=3, padding=1)

    # This is supposed to be a squeeze and excitation block.
    self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    self.scale = nn.Sequential(nn.Linear(out, out), nn.Sigmoid())

  def forward(self, x):
    x = self.conv(x)

    b = x.shape[0]
    ap = self.avg_pool(x).view(b, -1)
    ap = self.scale(ap)
    ap = ap.view(b, -1, 1, 1)

    return x * ap


class ConvTest(test_base.TestCase):

  def test_conv1(self):
    m = CustomConv1()
    arg = torch.randn((20, 1, 50))
    res = m(arg)

    jax_weights, jax_func = torch_xla2.extract_jax(m)
    arg = torch_xla2.tensor.t2j(arg)
    res2 = jax_func(jax_weights, (arg, ))
    res2_torch = torch_xla2.tensor.j2t(res2)
    self.assertTrue(torch.allclose(res, res2_torch))

  def test_conv2(self):
    m = CustomConv2()
    arg = torch.randn((20, 4, 50, 100))
    res = m(arg)
    jax_weights, jax_func = torch_xla2.extract_jax(m)
    arg = torch_xla2.tensor.t2j(arg)
    res2 = jax_func(jax_weights, (arg, ))
    res2_torch = torch_xla2.tensor.j2t(res2)
    self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4))


if __name__ == '__main__':
  test_base.main()
12 changes: 12 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
@@ -1762,6 +1762,18 @@ def test_aten_index_put_2(self):
  kwargs = dict()
  run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs)

def test_aten_index_put_3(self):
  args = (
      torch.randint(0, 10, (10, 10)).to(torch.int32),
      [
          torch.randint(0, 10, (1,)).to(torch.int64),
      ],
      torch.randint(0, 10, (10,)).to(torch.int32),
      True,
  )
  kwargs = dict()
  run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs)

def test_aten_index_select_0(self):
  args = (
      torch.randn((2, 10)).to(torch.float32),
10 changes: 7 additions & 3 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,19 +1,23 @@
import jax
import torch
import torch._functorch
from torch._functorch import make_functional
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
from torch_xla2 import export, ops, ops_registry, tensor, tf_integration


def extract_jax(mod: torch.nn.Module):
  """Returns a pytree of jax.ndarray and a jax callable."""
  func, weights, buffer = torch._functorch.make_functional_with_buffers(mod)
  func, weights, buffer = make_functional.make_functional_with_buffers(mod)
  states = (weights, buffer)
  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)

  @jax.jit
  def jax_func(states, inputs):
    (states, inputs) = tensor.wrap((states, inputs))
    weights, buffer = states
    res = func(weights, buffer, *inputs)
    with tensor.XLADispatchMode():
      res = func(weights, buffer, *inputs)
    return tensor.unwrap(res)

  return states, jax_func
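For context, a usage sketch of extract_jax as exercised by the new conv tests (the Linear module and the shapes here are hypothetical, chosen only for illustration):

import torch
import torch_xla2

m = torch.nn.Linear(4, 2)                        # any torch.nn.Module
jax_weights, jax_func = torch_xla2.extract_jax(m)

x = torch_xla2.tensor.t2j(torch.randn(8, 4))     # torch.Tensor -> jax.Array
out = jax_func(jax_weights, (x,))                # inputs are passed as a tuple, as in the tests
print(torch_xla2.tensor.j2t(out).shape)          # back to torch: torch.Size([8, 2])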
15 changes: 10 additions & 5 deletions experimental/torch_xla2/torch_xla2/export.py
@@ -156,7 +156,7 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
):
return super().call_function(target, args, kwargs)

print('Running ', target.name(), '--------')
# print('Running ', target.name(), '--------')

op = ops_registry.lowerings.lookup(target)
if op is None:
@@ -166,13 +166,18 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:

def run_node(self, n) -> Any:
  res = super().run_node(n)
  if n.op == 'call_function':
    if hasattr(res, 'shape'):
      print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
  #if n.op == 'call_function':
  #  if hasattr(res, 'shape'):
  #    print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
  return res

from torch._decomp import get_decompositions
import torch._refs
_extra_decomp = get_decompositions(
    [torch.ops.aten.unfold]
)


_extra_decomp = {}


def exported_program_to_jax(exported_program):
25 changes: 16 additions & 9 deletions experimental/torch_xla2/torch_xla2/ops.py
@@ -2,7 +2,6 @@
"""Torch ops implemented using jax."""
import sys

import flax
import jax
from jax import numpy as jnp
import numpy as np
@@ -298,10 +297,13 @@ def _aten_empty(sizes, **kwargs):

@op(torch.ops.aten.index_put_)
@op(torch.ops.aten.index_put)
def _aten_index_put(self, indexes, values):
def _aten_index_put(self, indexes, values, accumulate=False):
  indexes = [slice(None, None, None) if i is None else i for i in indexes]
  indexes = tuple(indexes)
  return self.at[indexes].set(values)
  if accumulate:
    return self.at[indexes].add(values)
  else:
    return self.at[indexes].set(values)
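
A small sketch of what the new accumulate flag changes, using plain jax.numpy (illustration only, not part of the diff; the array values are hypothetical):

import jax.numpy as jnp

x = jnp.zeros((5,), dtype=jnp.int32)
idx = (jnp.array([1, 1, 3]),)
vals = jnp.array([10, 20, 30], dtype=jnp.int32)
# accumulate=False maps to .set(): on duplicate indices only one write survives.
print(x.at[idx].set(vals))
# accumulate=True maps to .add(): duplicate indices sum, i.e. [0, 30, 0, 30, 0],
# matching torch.ops.aten.index_put with accumulate=True.
print(x.at[idx].add(vals))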


@op(torch.ops.aten.index)
@@ -508,8 +510,11 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
  )

  if bias is not None:
    # TODO(qihqi): this is wrong
    bias = bias.reshape(bias.shape + (1,))
    # TODO(qihqi): bias always on channel?
    if len(bias.shape) == 1:
      shape = [1] * len(res.shape)
      shape[1] = bias.shape[0]
      bias = bias.reshape(tuple(shape))
    res = res + bias
  return res
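
This is the bug the PR title refers to: a 1-D bias of shape (C,) has to be reshaped so it broadcasts along the channel axis of the convolution result, not given a single trailing axis. A minimal sketch of the corrected broadcast in plain jax.numpy (illustration only; the shapes are hypothetical):

import jax.numpy as jnp

res = jnp.zeros((20, 16, 50, 100))   # conv output in NCHW layout
bias = jnp.arange(16.0)              # per-channel bias, shape (16,)

# Old reshape: (16,) -> (16, 1), which happens to broadcast for 3-D conv1d
# outputs (N, C, L) but misaligns or fails for 4-D NCHW outputs.
# New reshape: (16,) -> (1, 16, 1, 1), so the bias lines up with the channel axis.
shape = [1] * res.ndim
shape[1] = bias.shape[0]
out = res + bias.reshape(tuple(shape))
print(out.shape)  # (20, 16, 50, 100)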

@@ -1200,19 +1205,21 @@ def _aten_arange(
    start,
    end=None,
    step=1,
    *,
    dtype=None,
    layout=None,
    requires_grad=False,
    device=None,
    pin_memory=False,
    pin_memory=False
):
  if end is None:
    end = start
    start = 0
  return jnp.arange(
      start,
      end,
      step,
      dtype=dtype,
      layout=layout,
      device=device,
      pin_memory=pin_memory,
  )


8 changes: 5 additions & 3 deletions experimental/torch_xla2/torch_xla2/tensor.py
@@ -131,10 +131,12 @@ def __new__(cls, elem):
      shape[i] = 1
  if dtype is None:
    dtype = torch.float32
  return torch.Tensor._make_subclass(
  return torch.Tensor._make_wrapper_subclass(
      cls,
      torch.empty(shape, dtype=dtype, device="meta"),
      require_grad=False,
      shape,
      dtype=dtype,
      device='meta',
      requires_grad=False,
  )

def __init__(self, elem: jax.Array):
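The tensor.py change switches from _make_subclass to _make_wrapper_subclass, so the Python-level tensor carries only shape/dtype metadata on the meta device while the real data (a jax.Array) is held as an ordinary attribute. A stripped-down sketch of that pattern (hypothetical class name, illustration only; not the actual torch_xla2 implementation):

import torch

class WrappedTensor(torch.Tensor):

  def __new__(cls, shape, dtype=torch.float32, payload=None):
    # A storage-less tensor shell: only shape/dtype metadata, on the meta device.
    return torch.Tensor._make_wrapper_subclass(
        cls, shape, dtype=dtype, device='meta', requires_grad=False)

  def __init__(self, shape, dtype=torch.float32, payload=None):
    super().__init__()
    self._payload = payload  # the real data (e.g. a jax.Array) lives here

t = WrappedTensor((2, 3), payload=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(t.shape, t.dtype)  # torch.Size([2, 3]) torch.float32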