Skip to content

Commit

Permalink
Implement torch diagflat, and movedim as well as it is also needed as…
Browse files Browse the repository at this point in the history
… a helper function (#7378) (#8038)
  • Loading branch information
simonteozw committed Sep 18, 2024
1 parent 6b33f8f commit b79a46b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"combinations",
"complex",
"diag_embed",
"diagflat",
"diagonal_copy",
"diagonal_scatter",
"digamma",
Expand Down
11 changes: 11 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,17 @@ def _aten_diagonal(input, offset=0, dim1=0, dim2=1):
return jnp.diagonal(input, offset, dim1, dim2)


# aten.diagflat
@op(torch.ops.aten.diagflat)
def _aten_diagflat(input, offset=0):
return jnp.diagflat(jnp.array(input), offset)


@op(torch.ops.aten.movedim)
def _aten_movedim(input, source, destination):
return jnp.moveaxis(input, source, destination)


# aten.eq
@op(torch.ops.aten.eq)
def _aten_eq(input1, input2):
Expand Down

0 comments on commit b79a46b

Please sign in to comment.