From b075cf8155ad977c98c4a1617f6f6778521795ee Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Mon, 19 Aug 2024 14:49:47 -0700 Subject: [PATCH] [torch_xla2] Fix all aten jax ops getter (#7868) --- .../torch_xla2/torch_xla2/ops/__init__.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/__init__.py b/experimental/torch_xla2/torch_xla2/ops/__init__.py index 1be0e45378e..3ba99a250c2 100644 --- a/experimental/torch_xla2/torch_xla2/ops/__init__.py +++ b/experimental/torch_xla2/torch_xla2/ops/__init__.py @@ -1,9 +1,10 @@ def all_aten_jax_ops(): - # to load the ops - import torch_xla2.ops.jaten # type: ignore - import torch_xla2.ops.ops_registry # type: ignore - return { - key: val.func - for key, val in torch_xla2.ops.ops_registry.all_aten_ops - if val.is_jax_function - } \ No newline at end of file + # to load the ops + import torch_xla2.ops.jaten # type: ignore + import torch_xla2.ops.ops_registry # type: ignore + + return { + key: val.func + for key, val in torch_xla2.ops.ops_registry.all_aten_ops.items() + if val.is_jax_function + }