[WIP] Implement randn and empty
will-cromar committed May 2, 2024
1 parent 93ce054 commit afaebdf
Showing 3 changed files with 19 additions and 2 deletions.
2 changes: 2 additions & 0 deletions experimental/torch_xla2/test/test_functions.py
@@ -28,6 +28,8 @@ class TestTorchFunctions(parameterized.TestCase):
       ('full_2d', lambda: torch.full((2, 3), 3.141592)),
       ('full_2d_dtype', lambda: torch.full(
           (2, 3), 3.141592, dtype=torch.float16)),
+      ('randn_1d', lambda: torch.randn(4)),
+      ('randn_2d', lambda: torch.randn(2, 3)),
   )
   def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
     expected = func()
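Note on these new cases: an exact value comparison against eager PyTorch cannot hold for torch.randn, since it is lowered to JAX's counter-based RNG (see functions.py below), which produces a different stream than PyTorch's native generator even under the same seed. A structural check might look like the sketch below; XLAFunctionMode mirrors the class defined in functions.py, but the assertion style is illustrative, not the repository's actual harness.

import torch
from torch_xla2 import functions

# Hypothetical check: compare shape and dtype, not values, because JAX
# and PyTorch RNGs differ even when seeded identically.
expected = torch.randn(2, 3)           # eager PyTorch tensor
with functions.XLAFunctionMode():      # route torch.* calls to the jnp implementations
  actual = torch.randn(2, 3)           # backed by jax.random.normal
assert tuple(actual.shape) == tuple(expected.shape)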
2 changes: 1 addition & 1 deletion experimental/torch_xla2/torch_xla2/_ops.py
@@ -892,7 +892,7 @@ def _aten_minimum(self, other):


 def _scatter_index(dim, index):
-  """Returns a tuple of indexes; 
+  """Returns a tuple of indexes;
   The first is to select in input (to modify),
   the second is to select from the values.
17 changes: 16 additions & 1 deletion experimental/torch_xla2/torch_xla2/functions.py
@@ -4,6 +4,7 @@
 from typing import Callable, Optional, ParamSpec, Sequence

 import jax
+import numpy as np
 import torch
 import jax.numpy as jnp
 from torch_xla2 import tensor
@@ -111,12 +112,26 @@ def _torch_argsort(input, dim=-1, descending=False, stable=False):
     # behavior is the same as a jnp array of rank 1
     expanded = True
     input = jnp.expand_dims(input, 0)
-  res = jnp.argsort(input, axis=dim, descending=descending, 
+  res = jnp.argsort(input, axis=dim, descending=descending,
                     stable=stable)
   if expanded:
     res = res.squeeze()
   return res

+@register_function(torch.empty)
+@convert_dtype()
+def _empty(*size: int, dtype=None, **kwargs):
+  return jnp.empty(size, dtype)
+
+@register_function(torch.randn)
+@convert_dtype()
+def _randn(*size: int, generator: Optional[torch.Generator] = None, dtype=None):
+  if not generator:
+    key = jax.random.key(np.uint64(torch.random.seed()))
+  else:
+    key = jax.random.key(generator.seed())
+
+  return jax.random.normal(key, size, dtype)


 class XLAFunctionMode(torch.overrides.TorchFunctionMode):
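A note on the RNG bridge in _randn: torch.random.seed() reseeds PyTorch's default generator to a fresh nondeterministic value and returns it as a 64-bit integer, which is wrapped in np.uint64 to build a JAX PRNG key; jax.random.normal(key, size, dtype) then draws the sample. Because JAX's RNG is purely functional, reusing one key returns identical values on every call, so a fresh key per invocation is required. A standalone sketch of the same calls, with a fixed seed standing in for the value torch would supply:

import jax
import jax.numpy as jnp
import numpy as np

seed = np.uint64(42)          # stand-in for torch.random.seed()
key = jax.random.key(seed)    # counter-based PRNG key; same key -> same sample
sample = jax.random.normal(key, (2, 3), jnp.float32)

# _empty maps to jnp.empty; JAX actually returns a zero-filled array here,
# since XLA buffers are never left uninitialized.
buf = jnp.empty((2, 3), jnp.float32)

One caveat worth flagging for this WIP: torch.Generator.seed() likewise reseeds the generator nondeterministically, so a generator prepared with manual_seed() will not make _randn reproducible as written.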
