Skip to content

Commit

Permalink
[torch_xla2] Fix all aten jax ops getter (#7868)
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc committed Aug 19, 2024
1 parent 940a4b1 commit b075cf8
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions experimental/torch_xla2/torch_xla2/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
def all_aten_jax_ops():
# to load the ops
import torch_xla2.ops.jaten # type: ignore
import torch_xla2.ops.ops_registry # type: ignore
return {
key: val.func
for key, val in torch_xla2.ops.ops_registry.all_aten_ops
if val.is_jax_function
}
# to load the ops
import torch_xla2.ops.jaten # type: ignore
import torch_xla2.ops.ops_registry # type: ignore

return {
key: val.func
for key, val in torch_xla2.ops.ops_registry.all_aten_ops.items()
if val.is_jax_function
}

0 comments on commit b075cf8

Please sign in to comment.