Skip to content

Commit

Permalink
Add fixes for narrow_copy and narrow ops
Browse files Browse the repository at this point in the history
  • Loading branch information
matinehAkhlaghinia committed Sep 15, 2024
1 parent 5f82da9 commit 7635eb1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 0 additions & 3 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,6 @@
"mvlgamma",
"nanmedian",
"nanquantile",
"nansum",
"narrow_copy",
"narrow",
"native_layer_norm",
"new_empty",
"new_empty_strided",
Expand Down
3 changes: 3 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def _aten_triu(m, k):
@op(torch.ops.aten.slice)
@op(torch.ops.aten.slice_copy)
def _aten_slice(self, dim=0, start=None, end=None, step=1):
if dim < 0:
dim += self.ndim
if end == sys.maxsize:
end = self.shape[dim]
sl = slice(start, end, step)
Expand Down Expand Up @@ -3220,6 +3222,7 @@ def f(k, carry):


@op(torch.ops.aten.narrow)
@op(torch.ops.aten.narrow_copy)
def _aten_narrow(input, dim, start, length):
return jax.lax.dynamic_slice_in_dim(input, start, length, axis=dim)

Expand Down

0 comments on commit 7635eb1

Please sign in to comment.