From 4fe55f559216493b6f8296fcd20a4e2a5673fa17 Mon Sep 17 00:00:00 2001
From: Greg Shikhman
Date: Tue, 17 Sep 2024 17:44:35 -0400
Subject: [PATCH 1/2] Fix unbind op. (#8033)

---
 experimental/torch_xla2/test/test_ops.py        | 1 -
 experimental/torch_xla2/torch_xla2/ops/jaten.py | 6 +-----
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 3b95476d413..7b7004d3eae 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -190,7 +190,6 @@
     "take_along_dim",
     "to_sparse",  # We are not supporting sparse tensors yet.
     "triu",
-    "unbind",
     "unfold_copy",
     "unfold",
     "unique_consecutive",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index a7ffa8ddec9..ea7b60484f9 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2396,13 +2396,9 @@ def _aten_trunc(a):
   return jnp.trunc(a)


-@op(torch.ops.aten.unbind)
 @op(torch.ops.aten.unbind_copy)
 def _aten_unbind(a, dim=0):
-  return tuple(
-      _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim)
-      for i in range(a.shape[dim])
-  )
+  return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]


 # NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d

From 0415cc60a16c3f542ed4ca0aa8640fffd776efb0 Mon Sep 17 00:00:00 2001
From: Greg Shikhman
Date: Tue, 17 Sep 2024 17:56:51 -0400
Subject: [PATCH 2/2] Uncomment tests that are passing. (#8035)

---
 experimental/torch_xla2/test/test_ops.py | 62 ++----------------------
 1 file changed, 4 insertions(+), 58 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 7b7004d3eae..dc7eb8eee9c 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -15,11 +15,11 @@
     "_segment_reduce",
     "_upsample_bilinear2d_aa",
     "bincount",  # NOTE: dtype for int input torch gives float. This is weird.
-    "block_diag",
     "byte",
     "cat",
     "cauchy",
     "cdist",
+    "ceil",
     "cholesky",
     "cholesky_inverse",
     "cholesky_solve",
@@ -29,18 +29,13 @@
     "diagflat",
     "diagonal_copy",
     "diagonal_scatter",
-    "diff",
     "digamma",
-    "erfc",
     "erfinv",
-    "expand",
     "exponential",
-    "floor_divide",
     "gather",
     "gcd",
     "geometric",
     "geqrf",
-    "grid_sampler_2d",
     "histogram",  # hard op: AssertionError: Tensor-likes are not close!
     "histogramdd",  # TypeError: histogram requires ndarray or scalar arguments, got at position 1.
"igammac", @@ -68,7 +63,6 @@ "linalg.matrix_norm", "linalg.matrix_power", "linalg.matrix_rank", - "linalg.multi_dot", "linalg.norm", "linalg.pinv", "linalg.solve", @@ -78,7 +72,6 @@ "linalg.svdvals", "linalg.tensorinv", "linalg.tensorsolve", - "linalg.vander", "linalg.vector_norm", "linspace", "log_normal", @@ -86,7 +79,6 @@ "lu", "lu_solve", "lu_unpack", - "masked.argmin", "masked.median", "masked_scatter", "masked_select", @@ -97,11 +89,9 @@ "mvlgamma", "nanmedian", "nanquantile", - "native_layer_norm", "new_empty", "new_empty_strided", "nextafter", - "nn.functional.adaptive_avg_pool1d", "nn.functional.adaptive_avg_pool3d", "nn.functional.adaptive_max_pool1d", "nn.functional.adaptive_max_pool2d", @@ -110,9 +100,7 @@ "nn.functional.avg_pool1d", "nn.functional.avg_pool2d", "nn.functional.avg_pool3d", - "nn.functional.batch_norm", "nn.functional.bilinear", - "nn.functional.binary_cross_entropy", "nn.functional.conv2d", "nn.functional.conv3d", "nn.functional.conv_transpose1d", @@ -120,7 +108,6 @@ "nn.functional.conv_transpose3d", "nn.functional.cosine_embedding_loss", "nn.functional.cosine_similarity", - "nn.functional.cross_entropy", "nn.functional.ctc_loss", "nn.functional.dropout2d", "nn.functional.dropout3d", @@ -131,9 +118,7 @@ "nn.functional.fractional_max_pool3d", "nn.functional.group_norm", "nn.functional.hinge_embedding_loss", - "nn.functional.instance_norm", "nn.functional.interpolate", - "nn.functional.layer_norm", "nn.functional.margin_ranking_loss", "nn.functional.max_pool1d", "nn.functional.max_pool2d", @@ -144,17 +129,12 @@ "nn.functional.multi_head_attention_forward", "nn.functional.multi_margin_loss", "nn.functional.multilabel_margin_loss", - "nn.functional.multilabel_soft_margin_loss", - "nn.functional.nll_loss", - "nn.functional.normalize", - "nn.functional.one_hot", "nn.functional.pad", "nn.functional.pairwise_distance", - "nn.functional.pixel_shuffle", - "nn.functional.pixel_unshuffle", "nn.functional.poisson_nll_loss", "nn.functional.rrelu", - "nn.functional.softmin", + "nn.functional.triplet_margin_loss", + "nn.functional.triplet_margin_with_distance_loss", "nn.functional.unfold", "nn.functional.upsample_nearest", "nonzero", @@ -187,53 +167,19 @@ "sub", "svd", "svd_lowrank", - "take_along_dim", "to_sparse", # We are not supporting sparse tensors yet. - "triu", "unfold_copy", "unfold", "unique_consecutive", "unique", "unravel_index", + "trunc", "var_mean", - "zero_", "argwhere", - "cumulative_trapezoid", - "expand_as", "nanmean", - "bmm", - "broadcast_shapes", - "cartesian_prod", - "cdouble", - "ceil", "chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180 - "nn.functional.smooth_l1_loss", - "nn.functional.soft_margin_loss", - "nn.functional.softplus", - "nn.functional.softshrink", - "nn.functional.softsign", - "nn.functional.tanhshrink", - "nn.functional.threshold", - "nn.functional.triplet_margin_loss", - "nn.functional.triplet_margin_with_distance_loss", "nn.functional.upsample_bilinear", - "outer", - "permute", - "positive", - "rad2deg", "randint", - "ravel", - "reciprocal", - "remainder", - "repeat", - "true_divide", - "trunc", - "unflatten", - "unsafe_chunk", - "unsafe_split", - "unsqueeze", - "view_as_complex", - "view_as", } # These inputs are themselves views