From 0e35022f01a1bb89aff83f578dc62122a6a90d33 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Fri, 16 Aug 2024 17:33:58 -0700
Subject: [PATCH] as_strided should apply to the input tensor, not the base
 tensor (#7864)

---
 test/test_operations.py         | 15 +++++++++++++++
 torch_xla/csrc/aten_xla_type.cpp |  6 ++----
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index a3617453672..94b88754621 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -1006,6 +1006,19 @@ def func(a, b):
     b = torch.ones([2, 2])
     self.runAtenTest((a, b), func)
 
+  def test_multi_view(self):
+
+    def func(x):
+      a1, b1 = x.chunk(2)
+      a2, b2 = x[0:1], x[1:2]
+      a3, b3 = x[0].unsqueeze(0), x[1].unsqueeze(0)
+      a4, b4 = x[0, None], x[1, None]
+      return a1.squeeze(), b1.squeeze(), a2.squeeze(), b2.squeeze(), a3.squeeze(
+      ), b3.squeeze(), a4.squeeze(), b4.squeeze()
+
+    x = torch.randn(size=[2])
+    self.runAtenTest(x, func)
+
   # TODO - upstream behavior has changed and results in expected DestroyXlaTensor
   # counter as of 11/13/2023. Re-enable after reviewing the change.
   # @skipIfFunctionalizationDisabled("metrics differ")
@@ -2691,6 +2704,8 @@ def from_tensors(self, tensors):
     self.assertEqual(xdata.batch_sizes.device, torch.device('cpu'))
     self.assertEqual(xdata.data.device, xla_device)
 
+  @skipIfFunctionalizationDisabled(
+      "https://github.com/pytorch/xla/pull/7864#issuecomment-2294034008")
   def test_as_strided_input_larger(self):
     size = (5, 5)
     device = xm.xla_device()
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index cded9b114ca..d355d6c378f 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -4167,15 +4167,13 @@ at::Tensor XLANativeFunctions::as_strided(
     const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
     std::optional<int64_t> storage_offset) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  const auto& base = bridge::GetXlaTensor(self)->Base();
-  const auto& tensor = base.defined() ? base : self;
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
+  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
   auto xsize = XlaHelpers::I64List(size);
   auto xstride = XlaHelpers::I64List(stride);
   if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
                                     storage_offset.value_or(0))) {
     return at::native::call_fallback_fn<
-        &xla_fallback, ATEN_OP(as_strided)>::call(tensor, size, stride,
+        &xla_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
                                                   storage_offset);
   }
   return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
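
--
Reviewer note (illustrative, not part of the patch): under functionalization,
regenerating a view such as a chunk or a slice can go through as_strided.
Before this fix, the XLA as_strided lowering resolved against the base tensor
rather than the view passed in, so sibling views could end up reading the
wrong slice of the base. The sketch below mirrors the spirit of the new
test_multi_view test; it assumes torch_xla is installed with an XLA device
reachable, and the variable names are ours, not part of the change.

# Minimal repro sketch for the behavior this patch fixes (assumes
# torch_xla is installed and an XLA device is available).
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(2)

# Two sibling views of the same base tensor; regenerating them under
# functionalization can lower to as_strided.
a, b = x.chunk(2)
xa, xb = x.to(device).chunk(2)

# With the fix, the XLA results match the CPU reference. Before it,
# as_strided could be computed on the base tensor x, so xb.squeeze()
# could return the same element as xa.squeeze().
assert torch.allclose(a.squeeze(), xa.squeeze().cpu())
assert torch.allclose(b.squeeze(), xb.squeeze().cpu())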