Skip to content

Commit

Permalink
Add new_empty (#8087)
Browse files Browse the repository at this point in the history
Co-authored-by: Tracy Chen <tracych@google.com>
  • Loading branch information
tracych477 and Tracy Chen committed Sep 27, 2024
1 parent 248a5bd commit ecc0f5a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
"multinomial",
"mvlgamma",
"nanmedian",
"new_empty",
"new_empty_strided",
"nextafter",
"nn.functional.adaptive_avg_pool3d",
Expand Down Expand Up @@ -167,6 +166,7 @@
'empty_permuted',
'empty_strided',
'bernoulli',
"new_empty",
'randint_like',
'randn',
'randn_like',
Expand Down
7 changes: 5 additions & 2 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,6 @@ def _aten__to_copy(self, **kwargs):
return jnp.copy(self)




@op(torch.ops.aten.empty)
@op_base.convert_dtype()
def _aten_empty(size: Sequence[int], *, dtype=None, **kwargs):
Expand Down Expand Up @@ -3930,6 +3928,11 @@ def _aten_flatten(x, start_dim=0, end_dim=-1):
return jnp.reshape(x, new_shape)


@op(torch.ops.aten.new_empty)
def _new_empty(self, size, **kwargs):
return jnp.empty(size)


@op(torch.ops.aten.new_empty_strided)
def _new_empty_strided(self, size, stride, **kwargs):
return jnp.empty(size)
Expand Down

0 comments on commit ecc0f5a

Please sign in to comment.