
Commit

Test and fix bug for convolution
qihqi committed Feb 8, 2024
1 parent d24abfb commit 5960449
Showing 4 changed files with 134 additions and 5 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/torch_xla2.yml
@@ -0,0 +1,45 @@
on:
  pull_request:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'experimental/torch_xla2/**'
  push:
    branches:
      - master
      - r[0-9]+.[0-9]+
    paths:
      - 'experimental/torch_xla2/**'
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

jobs:
  torchxla2-cpu:
    runs-on: ubuntu-20.04
    steps:
      - name: Setup Linux
        uses: pytorch/test-infra/.github/actions/setup-linux@main
      - name: Checkout repo
        uses: actions/checkout@v4
        with:
          sparse-checkout: |
            experimental/torch_xla2/**
      - name: Install
        shell: bash
        working-directory: experimental/torch_xla2
        run: |
          pip install -e .
          pip install pytest
      - name: Run tests
        working-directory: experimental/torch_xla2
        shell: bash
        run: |
          pytest test/
      - name: Teardown Linux
        uses: pytorch/test-infra/.github/actions/teardown-linux@main
        if: always()
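
The job above only does an editable install and runs pytest from experimental/torch_xla2. A rough local equivalent of those two steps, as a sketch (run from the pytorch/xla repository root; assumes pip and pytest are available on PATH):

import subprocess

# Mirror the "Install" step of the workflow above.
subprocess.run(["pip", "install", "-e", "."], cwd="experimental/torch_xla2", check=True)
subprocess.run(["pip", "install", "pytest"], cwd="experimental/torch_xla2", check=True)

# Mirror the "Run tests" step.
subprocess.run(["pytest", "test/"], cwd="experimental/torch_xla2", check=True)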
79 changes: 79 additions & 0 deletions experimental/torch_xla2/test/test_conv.py
@@ -0,0 +1,79 @@
import torch
from torch import nn
import torch_xla2
from . import test_base

class CustomConv1(torch.nn.Module):

  def __init__(
      self,
      channels_conv1=3,
      width_conv1=3,
      channels_conv2=5,
      width_conv2=5,
      hidden_layer_size=50,
  ):
    super(CustomConv1, self).__init__()
    self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1)
    self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2)
    self.fc1 = nn.Linear(hidden_layer_size, 2)

  def forward(self, x):
    x = nn.functional.max_pool1d(nn.functional.relu(self.conv1(x)), 2, stride=2)
    x = nn.functional.max_pool1d(nn.functional.relu(self.conv2(x)), 2, stride=2)
    x = torch.flatten(x, 1)
    x = nn.functional.softmax(self.fc1(x), dim=1)
    return x


class CustomConv2(nn.Module):

  def __init__(self):
    super().__init__()
    inp = 4
    out = 16

    self.conv = nn.Conv2d(inp, out, kernel_size=3, padding=1)

    # This is supposed to be a squeeze and excitation block.
    self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    self.scale = nn.Sequential(nn.Linear(out, out), nn.Sigmoid())

  def forward(self, x):
    x = self.conv(x)

    b = x.shape[0]
    ap = self.avg_pool(x).view(b, -1)
    ap = self.scale(ap)
    ap = ap.view(b, -1, 1, 1)

    return x * ap


class ConvTest(test_base.TestCase):

  def test_conv1(self):
    m = CustomConv1()
    arg = torch.randn((20, 1, 50))
    res = m(arg)

    jax_weights, jax_func = torch_xla2.extract_jax(m)
    arg = torch_xla2.tensor.t2j(arg)
    res2 = jax_func(jax_weights, (arg, ))
    res2_torch = torch_xla2.tensor.j2t(res2)
    self.assertTrue(torch.allclose(res, res2_torch))

  def test_conv2(self):
    m = CustomConv2()
    arg = torch.randn((20, 4, 50, 100))
    res = m(arg)
    jax_weights, jax_func = torch_xla2.extract_jax(m)
    arg = torch_xla2.tensor.t2j(arg)
    res2 = jax_func(jax_weights, (arg, ))
    res2_torch = torch_xla2.tensor.j2t(res2)
    self.assertTrue(torch.allclose(res, res2_torch, atol=1e-4, rtol=1e-4))


if __name__ == '__main__':
  test_base.main()
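
A side note on the shapes these tests exercise: for the (20, 1, 50) input in test_conv1, the layer sizes work out so that the flattened features match fc1's default hidden_layer_size=50. A quick sketch, derived from the layer definitions above rather than from anything in the commit:

import torch
from torch import nn

x = torch.randn(20, 1, 50)
x = nn.functional.max_pool1d(nn.functional.relu(nn.Conv1d(1, 3, 3)(x)), 2, stride=2)
# conv1: (20, 1, 50) -> (20, 3, 48); pooling halves the length -> (20, 3, 24)
x = nn.functional.max_pool1d(nn.functional.relu(nn.Conv1d(3, 5, 5)(x)), 2, stride=2)
# conv2: (20, 3, 24) -> (20, 5, 20); pooling -> (20, 5, 10)
x = torch.flatten(x, 1)
print(x.shape)  # torch.Size([20, 50]): 5 channels * length 10 = 50 features for fc1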
7 changes: 5 additions & 2 deletions experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,13 +1,16 @@
import jax
import torch
import torch._functorch
from torch._functorch import make_functional
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
from torch_xla2 import export, ops, ops_registry, tensor, tf_integration


def extract_jax(mod: torch.nn.Module):
"""Returns a pytree of jax.ndarray and a jax callable."""
func, weights, buffer = torch._functorch.make_functional_with_buffers(mod)
func, weights, buffer = make_functional.make_functional_with_buffers(mod)
states = (weights, buffer)
states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)

@jax.jit
def jax_func(states, inputs):
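
(The hunk is cut off here; the body of jax_func is not shown.) Pieced together from the calls in test_conv.py above, extract_jax is used roughly as follows; treat this as an illustrative sketch rather than documented API, with the Linear model as a stand-in:

import torch
import torch_xla2

model = torch.nn.Linear(4, 2)  # stand-in module
inputs = torch.randn(8, 4)

jax_weights, jax_func = torch_xla2.extract_jax(model)  # pytree of jax arrays + jitted callable
jax_inputs = torch_xla2.tensor.t2j(inputs)             # torch.Tensor -> jax array
jax_out = jax_func(jax_weights, (jax_inputs,))         # note: inputs are passed as a tuple
torch_out = torch_xla2.tensor.j2t(jax_out)             # jax array -> torch.Tensor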
8 changes: 5 additions & 3 deletions experimental/torch_xla2/torch_xla2/ops.py
@@ -2,7 +2,6 @@
"""Torch ops implemented using jax."""
import sys

import flax
import jax
from jax import numpy as jnp
import numpy as np
@@ -508,8 +507,11 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
  )

  if bias is not None:
    # TODO(qihqi): this is wrong
    bias = bias.reshape(bias.shape + (1,))
    # TODO(qihqi): bias always on channel?
    if len(bias.shape) == 1:
      shape = [1] * len(res.shape)
      shape[1] = bias.shape[0]
      bias = bias.reshape(tuple(shape))
    res = res + bias
  return res

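
The fix reshapes a 1-D per-channel bias to (1, C, 1, ..., 1) so it broadcasts along the channel axis of the NCHW-style convolution result, whereas the old reshape to bias.shape + (1,) lined the bias up against the trailing spatial axes. A small check of that broadcasting logic in plain NumPy (jnp follows the same rules; the shapes are illustrative, not from the commit):

import numpy as np

res = np.zeros((2, 16, 8, 8))   # N, C, H, W convolution output
bias = np.arange(16.0)          # one value per output channel

shape = [1] * res.ndim
shape[1] = bias.shape[0]        # -> [1, 16, 1, 1]
fixed = res + bias.reshape(tuple(shape))
assert fixed[0, 3, 0, 0] == bias[3]  # bias lands on the channel axis

# The old code reshaped bias to bias.shape + (1,) == (16, 1), which broadcasts
# against the last two (spatial) axes and raises here because the 16 channels
# line up against H=8:
# res + bias.reshape(bias.shape + (1,))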
