Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpInfo test for take -- trace #8009

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,7 @@
"svd",
"svd_lowrank",
"take_along_dim",
"take",
"tensor_split",
"to_sparse",
"topk",
"trace",
"to_sparse", # We are not supporting sparse tensors yet.
"triu",
"unbind",
"unfold_copy",
Expand Down Expand Up @@ -362,6 +358,13 @@ def run_export_and_compare(testcase,
test.variant_test_name not in variant_test_name_to_skip)
]

# Ops whose outputs include sorted indices: when input values tie,
# several index orderings are equally valid — e.g. sort([1, 0, 0])
# yields values [0, 0, 1] with indices [1, 2, 0] or [2, 1, 0] —
# so comparisons for these ops should skip the index output.
should_ignore_indexes = {"topk"}


class TestOpInfo(TestCase):

Expand Down Expand Up @@ -392,8 +395,11 @@ def test_reference_eager(self, device, dtype, op):
# To avoid errors during testing, replace values below 1 with 1.
sample_input.input = self.replace_values_below_threshold(
sample_input.input, 1)

ignore_index = op.name in should_ignore_indexes

run_export_and_compare(self, op, sample_input, check_output)
run_export_and_compare(self, op, sample_input, check_output,
ignore_indices=ignore_index)


instantiate_device_type_tests(TestOpInfo, globals())
Expand Down
3 changes: 3 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -2265,6 +2265,9 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
if not largest:
input = -input # Find top-k of negated input if we want the smallest

if input.ndim == 0:
return input, jnp.array(0, dtype=jnp.int64.dtype)

transpose_shape = None
if dim != -1 and dim != len(input.shape) - 1:
transpose_shape = list(range(len(input.shape)))
Expand Down
7 changes: 6 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ def logdet(input):


@register_function(torch.linalg.slogdet)
def slogdet(input):
def linalg_slogdet(input):
sign, logabsdet = jaten._aten__linalg_slogdet(input)
return torch.return_types.slogdet((sign, logabsdet))


@register_function(torch.tensor_split)
def tensor_split(input, indices_or_sections, dim=0):
  """Lowering for torch.tensor_split.

  Delegates to jax.numpy.array_split, which accepts either a section
  count or an explicit sequence of split indices, matching the torch
  contract of `indices_or_sections`.
  """
  pieces = jnp.array_split(input, indices_or_sections, axis=dim)
  return pieces
Loading