From 96c8e3047da6a72e3600280da665912504e1b4db Mon Sep 17 00:00:00 2001
From: qihqi
Date: Thu, 8 Feb 2024 09:53:31 -0800
Subject: [PATCH] Fix a bug in convolution for torch_xla2. Enable tests (#6496)

---
 .github/workflows/torch_xla2.yml              | 44 +++++++++++
 experimental/torch_xla2/requirements.txt      | 27 +++++--
 .../torch_xla2/test/llama/llama_model.py      |  4 +
 experimental/torch_xla2/test/test_conv.py     | 79 +++++++++++++++++++
 .../torch_xla2/test/test_core_aten_ops.py     | 12 +++
 .../torch_xla2/torch_xla2/__init__.py         | 10 ++-
 experimental/torch_xla2/torch_xla2/export.py  | 15 ++--
 experimental/torch_xla2/torch_xla2/ops.py     | 25 +++---
 experimental/torch_xla2/torch_xla2/tensor.py  |  8 +-
 9 files changed, 199 insertions(+), 25 deletions(-)
 create mode 100644 .github/workflows/torch_xla2.yml
 create mode 100644 experimental/torch_xla2/test/test_conv.py

diff --git a/.github/workflows/torch_xla2.yml b/.github/workflows/torch_xla2.yml
new file mode 100644
index 000000000000..b93ecba50ee7
--- /dev/null
+++ b/.github/workflows/torch_xla2.yml
@@ -0,0 +1,44 @@
+on:
+  pull_request:
+    branches:
+      - master
+      - r[0-9]+.[0-9]+
+    paths:
+      - 'experimental/torch_xla2/**'
+  push:
+    branches:
+      - master
+      - r[0-9]+.[0-9]+
+    paths:
+      - 'experimental/torch_xla2/**'
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
+  cancel-in-progress: true
+
+jobs:
+  torchxla2-cpu:
+    runs-on: ubuntu-20.04
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v4
+        with:
+          sparse-checkout: |
+            experimental/torch_xla2
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install
+        shell: bash
+        working-directory: experimental/torch_xla2
+        run: |
+          pip install pytest absl-py jax[cpu] flatbuffers tensorflow
+          pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
+          pip install -e .
+      - name: Run tests
+        working-directory: experimental/torch_xla2
+        shell: bash
+        run: |
+          pytest test/
\ No newline at end of file
diff --git a/experimental/torch_xla2/requirements.txt b/experimental/torch_xla2/requirements.txt
index fa373ce770fe..54d87fb64302 100644
--- a/experimental/torch_xla2/requirements.txt
+++ b/experimental/torch_xla2/requirements.txt
@@ -1,5 +1,22 @@
-jax==0.4.24.dev20240202
-jaxlib==0.4.24.dev20240202
-numpy==1.26.3
-torch==2.2.0
-typing_extensions==4.9.0
+certifi==2022.12.7
+charset-normalizer==2.1.1
+filelock==3.9.0
+fsspec==2023.4.0
+idna==3.4
+jax==0.4.24
+jaxlib==0.4.24
+Jinja2==3.1.2
+MarkupSafe==2.1.3
+ml-dtypes==0.3.2
+mpmath==1.2.1
+networkx==3.0rc1
+numpy==1.24.1
+opt-einsum==3.3.0
+Pillow==9.3.0
+requests==2.28.1
+scipy==1.12.0
+sympy==1.11.1
+--pre
+torch
+typing_extensions==4.8.0
+urllib3==1.26.13
diff --git a/experimental/torch_xla2/test/llama/llama_model.py b/experimental/torch_xla2/test/llama/llama_model.py
index 790afed02586..fb3df18459a4 100644
--- a/experimental/torch_xla2/test/llama/llama_model.py
+++ b/experimental/torch_xla2/test/llama/llama_model.py
@@ -244,6 +244,10 @@ def forward(
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
+        print('q=', q.shape)
+        print('k=', k.shape)
+        print('v=', v.shape)
+        print('mask=', mask.shape)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
diff --git a/experimental/torch_xla2/test/test_conv.py b/experimental/torch_xla2/test/test_conv.py
new file mode 100644
index 000000000000..de6873f0c0e4
--- /dev/null
+++ b/experimental/torch_xla2/test/test_conv.py
@@ -0,0 +1,79 @@
+import torch
+from torch import nn
+import torch_xla2
+from . import test_base
+
+class CustomConv1(torch.nn.Module):
+
+  def __init__(
+      self,
+      channels_conv1=3,
+      width_conv1=3,
+      channels_conv2=5,
+      width_conv2=5,
+      hidden_layer_size=50,
+  ):
+    super(CustomConv1, self).__init__()
+    self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1)
+    self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2)
+    self.fc1 = nn.Linear(hidden_layer_size, 2)
+
+  def forward(self, x):
+    x = nn.functional.max_pool1d(nn.functional.relu(self.conv1(x)), 2, stride=2)
+    x = nn.functional.max_pool1d(nn.functional.relu(self.conv2(x)), 2, stride=2)
+    x = torch.flatten(x, 1)
+    x = nn.functional.softmax(self.fc1(x), dim=1)
+    return x
+
+
+class CustomConv2(nn.Module):
+
+  def __init__(self):
+    super().__init__()
+    inp = 4
+    out = 16
+
+    self.conv = nn.Conv2d(inp, out, kernel_size=3, padding=1)
+
+    # This is supposed to be a squeeze and excitation block.
+    self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+    self.scale = nn.Sequential(nn.Linear(out, out), nn.Sigmoid())
+
+  def forward(self, x):
+    x = self.conv(x)
+
+    b = x.shape[0]
+    ap = self.avg_pool(x).view(b, -1)
+    ap = self.scale(ap)
+    ap = ap.view(b, -1, 1, 1)
+
+    return x * ap
+
+
+class ConvTest(test_base.TestCase):
+
+  def test_conv1(self):
+    m = CustomConv1()
+    arg = torch.randn((20, 1, 50))
+    res = m(arg)
+
+    jax_weights, jax_func = torch_xla2.extract_jax(m)
+    arg = torch_xla2.tensor.t2j(arg)
+    res2 = jax_func(jax_weights, (arg, ))
+    res2_torch = torch_xla2.tensor.j2t(res2)
+    self.assertTrue(torch.allclose(res, res2_torch))
+
+  def test_conv2(self):
+    m = CustomConv2()
+    arg = torch.randn((20, 4, 50, 100))
+    res = m(arg)
+    jax_weights, jax_func = torch_xla2.extract_jax(m)
+    arg = torch_xla2.tensor.t2j(arg)
+    res2 = jax_func(jax_weights, (arg, ))
+    res2_torch = torch_xla2.tensor.j2t(res2)
+    self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4))
+
+
+if __name__ == '__main__':
+  test_base.main()
\ No newline at end of file
diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py
index 83420e976080..ba9bf5d05823 100644
--- a/experimental/torch_xla2/test/test_core_aten_ops.py
+++ b/experimental/torch_xla2/test/test_core_aten_ops.py
@@ -1762,6 +1762,18 @@ def test_aten_index_put_2(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs)
 
+  def test_aten_index_put_3(self):
+    args = (
+        torch.randint(0, 10, (10, 10)).to(torch.int32),
+        [
+            torch.randint(0, 10, (1,)).to(torch.int64),
+        ],
+        torch.randint(0, 10, (10,)).to(torch.int32),
+        True,
+    )
+    kwargs = dict()
+    run_export_and_compare(self, torch.ops.aten.index_put, args, kwargs)
+
   def test_aten_index_select_0(self):
     args = (
         torch.randn((2, 10)).to(torch.float32),
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py
index 94e50b47c95a..1cbc67cf650b 100644
--- a/experimental/torch_xla2/torch_xla2/__init__.py
+++ b/experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,19 +1,23 @@
 import jax
 import torch
-import torch._functorch
+from torch._functorch import make_functional
+from torch.utils import _pytree as pytree
 from torch_xla2 import tensor
+from torch_xla2 import export, ops, ops_registry, tensor, tf_integration
 
 
 def extract_jax(mod: torch.nn.Module):
   """Returns a pytree of jax.ndarray and a jax callable."""
-  func, weights, buffer = torch._functorch.make_functional_with_buffers(mod)
+  func, weights, buffer = make_functional.make_functional_with_buffers(mod)
   states = (weights, buffer)
+  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
 
   @jax.jit
   def jax_func(states, inputs):
     (states, inputs) = tensor.wrap((states, inputs))
     weights, buffer = states
-    res = func(weights, buffer, *inputs)
+    with tensor.XLADispatchMode():
+      res = func(weights, buffer, *inputs)
     return tensor.unwrap(res)
 
   return states, jax_func
diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py
index 9d70be87d559..7a61341786ad 100644
--- a/experimental/torch_xla2/torch_xla2/export.py
+++ b/experimental/torch_xla2/torch_xla2/export.py
@@ -156,7 +156,7 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     ):
       return super().call_function(target, args, kwargs)
 
-    print('Running ', target.name(), '--------')
+    # print('Running ', target.name(), '--------')
 
     op = ops_registry.lowerings.lookup(target)
     if op is None:
@@ -166,13 +166,18 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
   def run_node(self, n) -> Any:
     res = super().run_node(n)
-    if n.op == 'call_function':
-      if hasattr(res, 'shape'):
-        print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
+    #if n.op == 'call_function':
+    #  if hasattr(res, 'shape'):
+    #    print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
     return res
 
 
+from torch._decomp import get_decompositions
+import torch._refs
+_extra_decomp = get_decompositions(
+    [torch.ops.aten.unfold]
+)
+
-_extra_decomp = {}
 
 
 def exported_program_to_jax(exported_program):
diff --git a/experimental/torch_xla2/torch_xla2/ops.py b/experimental/torch_xla2/torch_xla2/ops.py
index aa0643c61bc4..72fa291536da 100644
--- a/experimental/torch_xla2/torch_xla2/ops.py
+++ b/experimental/torch_xla2/torch_xla2/ops.py
@@ -2,7 +2,6 @@
 """Torch ops implemented using jax."""
 
 import sys
-import flax
 import jax
 from jax import numpy as jnp
 import numpy as np
@@ -298,10 +297,13 @@ def _aten_empty(sizes, **kwargs):
 
 @op(torch.ops.aten.index_put_)
 @op(torch.ops.aten.index_put)
-def _aten_index_put(self, indexes, values):
+def _aten_index_put(self, indexes, values, accumulate=False):
   indexes = [slice(None, None, None) if i is None else i for i in indexes]
   indexes = tuple(indexes)
-  return self.at[indexes].set(values)
+  if accumulate:
+    return self.at[indexes].add(values)
+  else:
+    return self.at[indexes].set(values)
 
 
 @op(torch.ops.aten.index)
@@ -508,8 +510,11 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
   )
 
   if bias is not None:
-    # TODO(qihqi): this is wrong
-    bias = bias.reshape(bias.shape + (1,))
+    # TODO(qihqi): bias always on channel?
+    if len(bias.shape) == 1:
+      shape = [1] * len(res.shape)
+      shape[1] = bias.shape[0]
+      bias = bias.reshape(tuple(shape))
     res = res + bias
   return res
 
@@ -1200,19 +1205,21 @@ def _aten_arange(
     start,
     end=None,
     step=1,
+    *,
     dtype=None,
     layout=None,
+    requires_grad=False,
     device=None,
-    pin_memory=False,
+    pin_memory=False
 ):
+  if end is None:
+    end = start
+    start = 0
   return jnp.arange(
       start,
       end,
       step,
       dtype=dtype,
-      layout=layout,
-      device=device,
-      pin_memory=pin_memory,
   )
 
diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py
index 77a2ad7d3281..eeb34d5238dd 100644
--- a/experimental/torch_xla2/torch_xla2/tensor.py
+++ b/experimental/torch_xla2/torch_xla2/tensor.py
@@ -131,10 +131,12 @@ def __new__(cls, elem):
         shape[i] = 1
     if dtype is None:
       dtype = torch.float32
-    return torch.Tensor._make_subclass(
+    return torch.Tensor._make_wrapper_subclass(
         cls,
-        torch.empty(shape, dtype=dtype, device="meta"),
-        require_grad=False,
+        shape,
+        dtype=dtype,
+        device='meta',
+        requires_grad=False,
     )
 
   def __init__(self, elem: jax.Array):
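
Reviewer notes on the changes above (illustrative sketches, not part of the commit):

The convolution fix in ops.py addresses the bias reshape: a 1-D bias of shape (C,) was previously reshaped to bias.shape + (1,), i.e. (C, 1), which happens to line up with a 1-D conv output of shape (N, C, L) but not with a 2-D output of shape (N, C, H, W); the new code builds an all-ones shape with the channel count in axis 1. A minimal standalone sketch of the broadcasting rule the fix relies on, assuming NCHW layout:

import jax.numpy as jnp

res = jnp.zeros((20, 16, 50, 100))  # conv output, NCHW
bias = jnp.arange(16.0)             # per-channel bias, shape (16,)

# Old reshape gave (16, 1), which aligns with the trailing (H, W) axes and
# either fails or broadcasts over the wrong dimensions for 2-D convolutions.
# New reshape: ones everywhere except the channel axis.
shape = [1] * res.ndim
shape[1] = bias.shape[0]            # -> (1, 16, 1, 1)
out = res + bias.reshape(tuple(shape))
print(out.shape)                    # (20, 16, 50, 100), bias added per channel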
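
The index_put lowering now honours the optional accumulate argument of aten.index_put, which test_aten_index_put_3 exercises. In jax.numpy terms this is the difference between a scatter-set and a scatter-add; a small sketch of the two paths (not the library code itself):

import jax.numpy as jnp

x = jnp.zeros(5, dtype=jnp.int32)
idx = (jnp.array([1, 1, 3]),)
vals = jnp.array([10, 20, 30], dtype=jnp.int32)

# accumulate=False path: one of the duplicate writes wins (order unspecified).
print(x.at[idx].set(vals))
# accumulate=True path: duplicate indices are summed -> [0 30 0 30 0]
print(x.at[idx].add(vals))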
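
The _aten_arange change follows the ATen calling convention: torch.arange(5) reaches the lowering as start=5 with end=None, so a single argument has to be reinterpreted as end, and jnp.arange does not accept layout, device or pin_memory. Roughly (a sketch, not the registered op):

import jax.numpy as jnp

def arange_sketch(start, end=None, step=1, *, dtype=None):
  if end is None:        # torch.arange(n) arrives as start=n, end=None
    end = start
    start = 0
  return jnp.arange(start, end, step, dtype=dtype)

print(arange_sketch(5).tolist())        # [0, 1, 2, 3, 4]
print(arange_sketch(2, 8, 2).tolist())  # [2, 4, 6]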
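
In tensor.py, switching from _make_subclass to _make_wrapper_subclass avoids materializing a separate meta tensor and fixes the keyword spelling (requires_grad, not require_grad). A stripped-down illustration of the constructor pattern with a hypothetical class name; the real XLATensor2 additionally stores the underlying jax.Array and implements dispatch:

import torch

class WrapperOnly(torch.Tensor):  # hypothetical name, for illustration only

  @staticmethod
  def __new__(cls, shape, dtype):
    # The wrapper carries only metadata (shape/dtype/device); the actual data
    # would live elsewhere, e.g. in a jax.Array held by the subclass.
    return torch.Tensor._make_wrapper_subclass(
        cls, shape, dtype=dtype, device='meta', requires_grad=False)

t = WrapperOnly((2, 3), torch.float32)
print(t.shape, t.dtype, t.device)  # torch.Size([2, 3]) torch.float32 meta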