From 3e98f2f1b0981e3cc9d418bdfe0d03cd21298499 Mon Sep 17 00:00:00 2001 From: Gregory Shikhman Date: Tue, 17 Sep 2024 20:39:41 +0000 Subject: [PATCH] Uncomment tests that are passing. I unskipped the tests in the skiplist and checked which tests were actually passing. --- experimental/torch_xla2/test/test_ops.py | 61 +----------------------- 1 file changed, 2 insertions(+), 59 deletions(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 3b95476d413..f57ff23b1a4 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -15,7 +15,6 @@ "_segment_reduce", "_upsample_bilinear2d_aa", "bincount", # NOTE: dtype for int input torch gives float. This is weird. - "block_diag", "byte", "cat", "cauchy", @@ -29,18 +28,13 @@ "diagflat", "diagonal_copy", "diagonal_scatter", - "diff", "digamma", - "erfc", "erfinv", - "expand", "exponential", - "floor_divide", "gather", "gcd", "geometric", "geqrf", - "grid_sampler_2d", "histogram", # hard op: AssertionError: Tensor-likes are not close! "histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got at position 1. "igammac", @@ -68,7 +62,6 @@ "linalg.matrix_norm", "linalg.matrix_power", "linalg.matrix_rank", - "linalg.multi_dot", "linalg.norm", "linalg.pinv", "linalg.solve", @@ -78,7 +71,6 @@ "linalg.svdvals", "linalg.tensorinv", "linalg.tensorsolve", - "linalg.vander", "linalg.vector_norm", "linspace", "log_normal", @@ -86,7 +78,6 @@ "lu", "lu_solve", "lu_unpack", - "masked.argmin", "masked.median", "masked_scatter", "masked_select", @@ -97,11 +88,9 @@ "mvlgamma", "nanmedian", "nanquantile", - "native_layer_norm", "new_empty", "new_empty_strided", "nextafter", - "nn.functional.adaptive_avg_pool1d", "nn.functional.adaptive_avg_pool3d", "nn.functional.adaptive_max_pool1d", "nn.functional.adaptive_max_pool2d", @@ -110,9 +99,7 @@ "nn.functional.avg_pool1d", "nn.functional.avg_pool2d", "nn.functional.avg_pool3d", - "nn.functional.batch_norm", "nn.functional.bilinear", - "nn.functional.binary_cross_entropy", "nn.functional.conv2d", "nn.functional.conv3d", "nn.functional.conv_transpose1d", @@ -120,7 +107,6 @@ "nn.functional.conv_transpose3d", "nn.functional.cosine_embedding_loss", "nn.functional.cosine_similarity", - "nn.functional.cross_entropy", "nn.functional.ctc_loss", "nn.functional.dropout2d", "nn.functional.dropout3d", @@ -131,9 +117,7 @@ "nn.functional.fractional_max_pool3d", "nn.functional.group_norm", "nn.functional.hinge_embedding_loss", - "nn.functional.instance_norm", "nn.functional.interpolate", - "nn.functional.layer_norm", "nn.functional.margin_ranking_loss", "nn.functional.max_pool1d", "nn.functional.max_pool2d", @@ -144,17 +128,12 @@ "nn.functional.multi_head_attention_forward", "nn.functional.multi_margin_loss", "nn.functional.multilabel_margin_loss", - "nn.functional.multilabel_soft_margin_loss", - "nn.functional.nll_loss", - "nn.functional.normalize", - "nn.functional.one_hot", "nn.functional.pad", "nn.functional.pairwise_distance", - "nn.functional.pixel_shuffle", - "nn.functional.pixel_unshuffle", "nn.functional.poisson_nll_loss", "nn.functional.rrelu", - "nn.functional.softmin", + "nn.functional.triplet_margin_loss", + "nn.functional.triplet_margin_with_distance_loss", "nn.functional.unfold", "nn.functional.upsample_nearest", "nonzero", @@ -187,54 +166,18 @@ "sub", "svd", "svd_lowrank", - "take_along_dim", "to_sparse", # We are not supporting sparse tensors yet. - "triu", - "unbind", "unfold_copy", "unfold", "unique_consecutive", "unique", "unravel_index", "var_mean", - "zero_", "argwhere", - "cumulative_trapezoid", - "expand_as", "nanmean", - "bmm", - "broadcast_shapes", - "cartesian_prod", - "cdouble", - "ceil", "chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180 - "nn.functional.smooth_l1_loss", - "nn.functional.soft_margin_loss", - "nn.functional.softplus", - "nn.functional.softshrink", - "nn.functional.softsign", - "nn.functional.tanhshrink", - "nn.functional.threshold", - "nn.functional.triplet_margin_loss", - "nn.functional.triplet_margin_with_distance_loss", "nn.functional.upsample_bilinear", - "outer", - "permute", - "positive", - "rad2deg", "randint", - "ravel", - "reciprocal", - "remainder", - "repeat", - "true_divide", - "trunc", - "unflatten", - "unsafe_chunk", - "unsafe_split", - "unsqueeze", - "view_as_complex", - "view_as", } # These inputs are themselves views