From ef0a1db9199eb1818e035da75977c3f7fed88596 Mon Sep 17 00:00:00 2001 From: Gregory Shikhman Date: Tue, 17 Sep 2024 19:04:39 +0000 Subject: [PATCH] Fix unbind op. The test passed with the previous implementation but I created a slightly more efficient one anyway. --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3b95476d413..7b7004d3eae 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -190,7 +190,6 @@ "take_along_dim", "to_sparse", # We are not supporting sparse tensors yet. "triu", - "unbind", "unfold_copy", "unfold", "unique_consecutive", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index a7ffa8ddec9..1bafebe7672 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -2399,10 +2399,7 @@ def _aten_trunc(a): @op(torch.ops.aten.unbind) @op(torch.ops.aten.unbind_copy) def _aten_unbind(a, dim=0): - return tuple( - _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim) - for i in range(a.shape[dim]) - ) + return [lax.index_in_dim(a, i, axis, keepdims=False) for i in range(a.shape[axis])] # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d