Skip to content

Commit

Permalink
Add atan2, arange, addcmul, addcdiv, _softmax_backward_data, T, H (#7681
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ManfeiBai committed Jul 16, 2024
1 parent b2c7f65 commit a69fd51
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
7 changes: 0 additions & 7 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,6 @@
"nanmean",
"trapezoid",
"trapz",
"H",
"T",
"_softmax_backward_data",
"addcdiv",
"addcmul",
"arange",
"atan2",
"atleast_1d",
"atleast_2d",
"atleast_3d",
Expand Down
21 changes: 19 additions & 2 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ def _handle_int64_trig(self, func):
return res


def _handle_int64_to_int32_trig(func, args):
target_type = None
for i in range(len(args)):
if args[i].dtype.name == 'int64':
target_type = jnp.dtype('int32')
if target_type is not None:
args[i] = args[i].astype(target_type)
res = func(*args)
return res


@op(
torch.ops.aten.view_copy,
torch.ops.aten.view,
Expand Down Expand Up @@ -1518,6 +1529,12 @@ def _aten_arange(
):
if dtype:
dtype = mappings.t2j_dtype(dtype)
if start and dtype:
start = jax.lax.convert_element_type(start, dtype)
if end and dtype:
end = jax.lax.convert_element_type(end, dtype)
if step and dtype:
step = jax.lax.convert_element_type(step, dtype)
return jnp.arange(
start,
end,
Expand Down Expand Up @@ -1562,8 +1579,8 @@ def _aten_as_strided_scatter(x, src, sizes, strides, storage_offset):

# aten.atan2
@op(torch.ops.aten.atan2)
def _aten_atan2(self, other):
return jnp.arctan2(self, other)
def _aten_atan2(input, other):
return _handle_int64_to_int32_trig(jnp.arctan2, [input, other])


# aten.bitwise_and
Expand Down

0 comments on commit a69fd51

Please sign in to comment.