Skip to content

Commit

Permalink
Add extra APIs. (pytorch#6586)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored and amithrm committed Mar 1, 2024
1 parent 9e759c8 commit eae04bb
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
31 changes: 31 additions & 0 deletions experimental/torch_xla2/test/test_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
import torch
import torch.nn.functional as F
import jax
import torch_xla2
from torch_xla2 import tensor, extra


class ExtraTest(unittest.TestCase):

def setUp(self):
torch.manual_seed(0)

def test_fori_loop(self):
a = tensor.move_to_device(torch.ones((10, 10)))

def body(i, c):
return c + a[i]

init_val = tensor.move_to_device(torch.zeros(10))
res = extra.fori_loop(0, 10, body, init_val)

expect = torch.ones(10) * 10

self.assertTrue(torch.allclose(tensor.j2t(res._elem), expect))




if __name__ == '__main__':
unittest.main()
41 changes: 41 additions & 0 deletions experimental/torch_xla2/torch_xla2/extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import jax
import functools
from torch.utils import _pytree as pytree
from torch_xla2 import tensor


def call_jax(jax_function, *args, **kwargs):
# args, kwargs are torch tensors
# return val is torch tensor
args, kwargs = tensor.unwrap((args, kwargs))
res = jax_function(*args, **kwargs)
return tensor.wrap(res)


def call_torch(torch_function, *args, **kwargs):
# args, kwargs are torch tensors
# return val is torch tensor
args, kwargs = tensor.wrap((args, kwargs))
res = torch_function(*args, **kwargs)
return tensor.unwrap(res)


def fori_loop(lower, upper, body_fn, init_val, *, unroll=None):
"""Torch fori_loop mimicking jax behavior.
Args:
lower: lower bound
upper: upperbound
init_val: init value (tree of torch.Tensors)
body_fn is a function that takes (int, a) -> a
where a is a pytree with torch.Tensors
unroll = False | True | int
"""
jax_body = functools.partial(call_torch, body_fn)
return call_jax(
jax.lax.fori_loop,
lower,
upper,
jax_body,
init_val,
unroll=unroll)

0 comments on commit eae04bb

Please sign in to comment.