As_strided should apply to the input tensor, not the base tensor (#7864)
JackCaoG committed Aug 17, 2024
1 parent 37312c1 commit 0e35022
Showing 2 changed files with 17 additions and 4 deletions.
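For context, here is a minimal Python sketch (not part of the commit) of the pattern this change affects, modeled on the new test_multi_view test below. The device setup and the assumption that these view ops route through as_strided are illustrative; the grounded fact from the diff is that XLANativeFunctions::as_strided previously substituted the view's base tensor for the input whenever a base was defined, and now operates on the input tensor itself.

# Hedged illustration only; mirrors the new test_multi_view test.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(2, device=device)

# chunk() and slicing return views of x; under functionalization these views
# keep a reference to their base tensor (x).
a1, b1 = x.chunk(2)
a2, b2 = x[0:1], x[1:2]

# Ops on these views can be replayed through as_strided. The size/stride/
# offset passed in describe the input view, so the lowering should start
# from the input tensor; the pre-fix code started from the base tensor
# instead when one was defined.
out = [t.squeeze() for t in (a1, b1, a2, b2)]
print([o.cpu() for o in out])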
15 changes: 15 additions & 0 deletions test/test_operations.py
@@ -1006,6 +1006,19 @@ def func(a, b):
     b = torch.ones([2, 2])
     self.runAtenTest((a, b), func)

+  def test_multi_view(self):
+
+    def func(x):
+      a1, b1 = x.chunk(2)
+      a2, b2 = x[0:1], x[1:2]
+      a3, b3 = x[0].unsqueeze(0), x[1].unsqueeze(0)
+      a4, b4 = x[0, None], x[1, None]
+      return a1.squeeze(), b1.squeeze(), a2.squeeze(), b2.squeeze(), a3.squeeze(
+      ), b3.squeeze(), a4.squeeze(), b4.squeeze()
+
+    x = torch.randn(size=[2])
+    self.runAtenTest(x, func)
+
   # TODO - upstream behavior has changed and results in expected DestroyXlaTensor
   # counter as of 11/13/2023. Re-enable after reviewing the change.
   # @skipIfFunctionalizationDisabled("metrics differ")
@@ -2691,6 +2704,8 @@ def from_tensors(self, tensors):
     self.assertEqual(xdata.batch_sizes.device, torch.device('cpu'))
     self.assertEqual(xdata.data.device, xla_device)

+  @skipIfFunctionalizationDisabled(
+      "https://github.com/pytorch/xla/pull/7864#issuecomment-2294034008")
   def test_as_strided_input_larger(self):
     size = (5, 5)
     device = xm.xla_device()
6 changes: 2 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -4167,15 +4167,13 @@ at::Tensor XLANativeFunctions::as_strided(
     const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
     std::optional<int64_t> storage_offset) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  const auto& base = bridge::GetXlaTensor(self)->Base();
-  const auto& tensor = base.defined() ? base : self;
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
+  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
   auto xsize = XlaHelpers::I64List(size);
   auto xstride = XlaHelpers::I64List(stride);
   if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
                                     storage_offset.value_or(0))) {
     return at::native::call_fallback_fn<
-        &xla_fallback, ATEN_OP(as_strided)>::call(tensor, size, stride,
+        &xla_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
                                                   storage_offset);
   }
   return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
