diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index c43232444cb..a7ffa8ddec9 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -462,6 +462,12 @@ def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs): return jnp.empty(size, dtype=dtype) +@op(torch.ops.aten.empty_like) +@op_base.convert_dtype() +def _aten_empty_like(input, *, dtype=None, **kwargs): + return jnp.empty_like(input, dtype=dtype) + + @op(torch.ops.aten.ones) @op_base.convert_dtype() def _ones(size: Sequence[int], dtype=None, **kwargs):