diff --git a/experimental/torch_xla2/torch_xla2/ops_registry.py b/experimental/torch_xla2/torch_xla2/ops_registry.py index 6830153a10a..53ac3a4f6d9 100644 --- a/experimental/torch_xla2/torch_xla2/ops_registry.py +++ b/experimental/torch_xla2/torch_xla2/ops_registry.py @@ -24,6 +24,9 @@ def _lookup(self, op): return candidate def register(self, op, lowering): + if isinstance(op, torch._ops.OpOverloadPacket): + if hasattr(op, 'default'): + self.registered_ops[op.default] = lowering self.registered_ops[op] = lowering