Skip to content

Commit

Permalink
Implement upsample_bilinear2d_aa
Browse files Browse the repository at this point in the history
  • Loading branch information
barney-s committed Sep 28, 2024
1 parent ecc0f5a commit 8a6db3e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"__rpow__", # NOTE: cannot fix because torch test case has undefined behavior
# such as 0 to negative power.
"_segment_reduce",
"_upsample_bilinear2d_aa",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"byte",
"cat",
Expand Down Expand Up @@ -270,6 +269,7 @@ def replace_values_below_threshold(self, torch_tensor, threshold):
def test_reference_eager(self, device, dtype, op):
sample_inputs = op.sample_inputs(device, dtype)
for sample_input in sample_inputs:
print("sample_input: ", op.name, sample_input)
t = sample_input.input
if isinstance(t, torch.Tensor) and t.is_sparse:
continue
Expand Down
31 changes: 31 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional, Sequence
import functools

import math
import jax
from jax import numpy as jnp
import functools
Expand Down Expand Up @@ -4095,3 +4096,33 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0):

return output

@op(torch.ops.aten._upsample_bilinear2d_aa)
def _aten_upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors):
# input: is of type jaxlib.xla_extension.ArrayImpl
image = input

# https://jax.readthedocs.io/en/latest/_autosummary/jax.image.resize.html
# Resize does not distinguish batch, channel size.
# We need to leave them as is
# https://pytorch.org/vision/stable/transforms.html#supported-input-types-and-conventions
# pytorch image shape is (C,H,W) or (N,C,H,W)
# N - batch size
# C - no of channels
# H,W - heigth, width

shape = list(image.shape)
# overriding output_size
if output_size:
shape[-1] = output_size[-1]
shape[-2] = output_size[-2]
if scale_factors:
shape[-1] = int(math.floor(shape[-1]*scale_factors[-1]))
shape[-2] = int(math.floor(shape[-2]*scale_factors[-2]))
method = "bilinear"
antialias = True # ignored for upsampling

# align_corners is not supported in resize()
# https://github.com/jax-ml/jax/issues/11206
if align_corners:
raise ValueError("align_corners=true not supported yet")
return jax.image.resize(image, shape, method, antialias) # precision=Precision.HIGHEST

0 comments on commit 8a6db3e

Please sign in to comment.