diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 2d50754829b..d6a03ae8f26 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -14,7 +14,6 @@ "__rpow__", # NOTE: cannot fix because torch test case has undefined behavior # such as 0 to negative power. "_segment_reduce", - "_upsample_bilinear2d_aa", "bincount", # NOTE: dtype for int input torch gives float. This is weird. "byte", "cat", @@ -270,6 +269,7 @@ def replace_values_below_threshold(self, torch_tensor, threshold): def test_reference_eager(self, device, dtype, op): sample_inputs = op.sample_inputs(device, dtype) for sample_input in sample_inputs: + print("sample_input: ", op.name, sample_input) t = sample_input.input if isinstance(t, torch.Tensor) and t.is_sparse: continue diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 4d5197488af..373c6df5e87 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1,9 +1,10 @@ """Torch ops implemented using jax.""" import sys -from typing import Optional, Sequence +from typing import Optional, Sequence, Tuple, Union import functools +import math import jax from jax import numpy as jnp import functools @@ -4095,3 +4096,59 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): return output +@op(torch.ops.aten._upsample_bilinear2d_aa) +def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors): + # input: is of type jaxlib.xla_extension.ArrayImpl + image = input + + # https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html + # Resize does not distinguish batch, channel size. + # We need to leave them as is + # https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions + # pytorch image shape is (C,H,W) or (N,C,H,W) + # N - batch size + # C - no of channels + # H,W - heigth, width + + shape = list(image.shape) + # overriding output_size + if output_size: + shape[-1] = output_size[-1] + shape[-2] = output_size[-2] + if scale_factors: + shape[-1] = int(math.floor(shape[-1]*scale_factors[-1])) + shape[-2] = int(math.floor(shape[-2]*scale_factors[-2])) + method = "bilinear" + antialias = True # ignored for upsampling + + # align_corners is not supported in resize() + # https://github.com/jax-ml/jax/issues/11206 + if align_corners: + return resize_with_aligned_corners(image, shape, method, antialias=True) + return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST + +# From: https://github.com/jax-ml/jax/issues/11206 +def resize_with_aligned_corners( + image: jax.Array, + shape: Tuple[int, ...], + method: Union[str, jax.image.ResizeMethod], + antialias: bool, +): + """Alternative to jax.image.resize(), which emulates align_corners=True in PyTorch's + interpolation functions.""" + spatial_dims = tuple( + i + for i in range(len(shape)) + if not jax.core.symbolic_equal_dim(image.shape[i], shape[i]) + ) + scale = jnp.array([(shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims]) + translation = -(scale / 2.0 - 0.5) + return jax.image.scale_and_translate( + image, + shape, + method=method, + scale=scale, + spatial_dims=spatial_dims, + translation=translation, + antialias=antialias, + ) \ No newline at end of file