Skip to content

Commit

Permalink
OpInfo test for take -- trace
Browse files Browse the repository at this point in the history
Explicitly mark sparse as unsupported

Fixes: #7966
  • Loading branch information
qihqi committed Sep 13, 2024
1 parent 24bc4b2 commit e99bfb2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
18 changes: 12 additions & 6 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,7 @@
"svd",
"svd_lowrank",
"take_along_dim",
"take",
"tensor_split",
"to_sparse",
"topk",
"trace",
"to_sparse", # We are not supporting sparse tensors yet.
"triu",
"unbind",
"unfold_copy",
Expand Down Expand Up @@ -362,6 +358,13 @@ def run_export_and_compare(testcase,
test.variant_test_name not in variant_test_name_to_skip)
]

# Sort related ops should ignore index;
# For example: sort( [1, 0, 0]) -> [0, 0, 1]
# the correct index can be [1, 2, 0] or [2, 1, 0]
should_ignore_indexes = {
"topk"
}


class TestOpInfo(TestCase):

Expand Down Expand Up @@ -392,8 +395,11 @@ def test_reference_eager(self, device, dtype, op):
# To avoid errors during testing, replace values below 1 with 1.
sample_input.input = self.replace_values_below_threshold(
sample_input.input, 1)

ignore_index = op.name in should_ignore_indexes

run_export_and_compare(self, op, sample_input, check_output)
run_export_and_compare(self, op, sample_input, check_output,
ignore_indices=ignore_index)


instantiate_device_type_tests(TestOpInfo, globals())
Expand Down
3 changes: 3 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2265,6 +2265,9 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
if not largest:
input = -input # Find top-k of negated input if we want the smallest

if input.ndim == 0:
return input, jnp.array(0, dtype=jnp.int64.dtype)

transpose_shape = None
if dim != -1 and dim != len(input.shape) - 1:
transpose_shape = list(range(len(input.shape)))
Expand Down
7 changes: 6 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ def logdet(input):


@register_function(torch.linalg.slogdet)
def slogdet(input):
def linalg_slogdet(input):
sign, logabsdet = jaten._aten__linalg_slogdet(input)
return torch.return_types.slogdet((sign, logabsdet))


@register_function(torch.tensor_split)
def tensor_split(input, indices_or_sections, dim=0):
return jnp.array_split(input, indices_or_sections, axis=dim)

0 comments on commit e99bfb2

Please sign in to comment.