From 183d1d418ac4ea231a6a69e44b4c4b5516467f08 Mon Sep 17 00:00:00 2001
From: qihqi
Date: Fri, 13 Sep 2024 12:17:33 -0700
Subject: [PATCH] OpInfo test for take -- trace (#8009)

---
 experimental/torch_xla2/test/test_ops.py       | 18 ++++++++++++------
 .../torch_xla2/torch_xla2/ops/jaten.py         |  3 +++
 .../torch_xla2/torch_xla2/ops/jtorch.py        |  7 ++++++-
 3 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 01400c53b39..65242adc900 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -223,11 +223,7 @@
     "svd",
     "svd_lowrank",
     "take_along_dim",
-    "take",
-    "tensor_split",
-    "to_sparse",
-    "topk",
-    "trace",
+    "to_sparse",  # We are not supporting sparse tensors yet.
     "triu",
     "unbind",
     "unfold_copy",
@@ -362,6 +358,13 @@ def run_export_and_compare(testcase,
     test.variant_test_name not in variant_test_name_to_skip)
 ]
 
+# Sort related ops should ignore index;
+# For example: sort( [1, 0, 0]) -> [0, 0, 1]
+# the correct index can be [1, 2, 0] or [2, 1, 0]
+should_ignore_indexes = {
+  "topk"
+}
+
 
 class TestOpInfo(TestCase):
 
@@ -392,8 +395,11 @@ def test_reference_eager(self, device, dtype, op):
         # To avoid errors during testing, replace values below 1 with 1.
         sample_input.input = self.replace_values_below_threshold(
             sample_input.input, 1)
+
+      ignore_index = op.name in should_ignore_indexes
 
-      run_export_and_compare(self, op, sample_input, check_output)
+      run_export_and_compare(self, op, sample_input, check_output,
+                             ignore_indices=ignore_index)
 
 
 instantiate_device_type_tests(TestOpInfo, globals())
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 03a2da7b380..25c5f74d49a 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2265,6 +2265,9 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
+  if input.ndim == 0:
+    return input, jnp.array(0, dtype=jnp.int64.dtype)
+
   if not largest:
     input = -input  # Find top-k of negated input if we want the smallest
 
   transpose_shape = None
   if dim != -1 and dim != len(input.shape) - 1:
     transpose_shape = list(range(len(input.shape)))
diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
index 346953e9e14..d53e522bc07 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jtorch.py
@@ -275,6 +275,11 @@ def logdet(input):
 
 
 @register_function(torch.linalg.slogdet)
-def slogdet(input):
+def linalg_slogdet(input):
   sign, logabsdet = jaten._aten__linalg_slogdet(input)
   return torch.return_types.slogdet((sign, logabsdet))
+
+
+@register_function(torch.tensor_split)
+def tensor_split(input, indices_or_sections, dim=0):
+  return jnp.array_split(input, indices_or_sections, axis=dim)