From 8a6db3e1d56b5fc6d75828a97a59730efdbeffab Mon Sep 17 00:00:00 2001 From: Barni Seetharaman Date: Fri, 27 Sep 2024 23:54:59 +0000 Subject: [PATCH] Implement upsample_bilinear2d_aa --- experimental/torch_xla2/test/test_ops.py | 2 +- .../torch_xla2/torch_xla2/ops/jaten.py | 31 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2d50754829b..d6a03ae8f26 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -14,7 +14,6 @@ "__rpow__", # NOTE: cannot fix because torch test case has undefined behavior # such as 0 to negative power. "_segment_reduce", - "_upsample_bilinear2d_aa", "bincount", # NOTE: dtype for int input torch gives float. This is weird. "byte", "cat", @@ -270,6 +269,7 @@ def replace_values_below_threshold(self, torch_tensor, threshold): def test_reference_eager(self, device, dtype, op): sample_inputs = op.sample_inputs(device, dtype) for sample_input in sample_inputs: + print("sample_input: ", op.name, sample_input) t = sample_input.input if isinstance(t, torch.Tensor) and t.is_sparse: continue diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 4d5197488af..dc65eed7104 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -4,6 +4,7 @@ from typing import Optional, Sequence import functools +import math import jax from jax import numpy as jnp import functools @@ -4095,3 +4096,33 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): return output +@op(torch.ops.aten._upsample_bilinear2d_aa) +def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors): + # input: is of type jaxlib.xla_extension.ArrayImpl + image = input + + # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html + # Resize does not distinguish batch, channel size. + # We need to leave them as is + # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions + # pytorch image shape is (C,H,W) or (N,C,H,W) + # N - batch size + # C - no of channels + # H,W - heigth, width + + shape = list(image.shape) + # overriding output_size + if output_size: + shape[-1] = output_size[-1] + shape[-2] = output_size[-2] + if scale_factors: + shape[-1] = int(math.floor(shape[-1]*scale_factors[-1])) + shape[-2] = int(math.floor(shape[-2]*scale_factors[-2])) + method = "bilinear" + antialias = True # ignored for upsampling + + # align_corners is not supported in resize() + # https://github.com/jax-ml/jax/issues/11206 + if align_corners: + raise ValueError("align_corners=true not supported yet") + return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST