diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index af8cc21b53a..2d50754829b 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -75,7 +75,6 @@
     "multinomial",
     "mvlgamma",
     "nanmedian",
-    "new_empty",
     "new_empty_strided",
     "nextafter",
     "nn.functional.adaptive_avg_pool3d",
@@ -167,6 +166,7 @@
     'empty_permuted',
     'empty_strided',
     'bernoulli',
+    "new_empty",
     'randint_like',
     'randn',
     'randn_like',
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 4181f11e3cf..4d5197488af 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -474,8 +474,6 @@ def _aten__to_copy(self, **kwargs):
   return jnp.copy(self)
 
 
-
-
 @op(torch.ops.aten.empty)
 @op_base.convert_dtype()
 def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs):
@@ -3930,6 +3928,13 @@ def _aten_flatten(x, start_dim=0, end_dim=-1):
   return jnp.reshape(x, new_shape)
 
 
+@op(torch.ops.aten.new_empty)
+@op_base.convert_dtype(use_default_dtype=False)
+def _new_empty(self, size, *, dtype=None, **kwargs):
+  # torch.Tensor.new_empty: result takes self's dtype unless dtype= is given.
+  return jnp.empty(size, dtype=dtype or self.dtype)
+
+
 @op(torch.ops.aten.new_empty_strided)
 def _new_empty_strided(self, size, stride, **kwargs):
   return jnp.empty(size)