Commit

Merge branch 'master' into op-fix-3

matinehAkhlaghinia committed Sep 18, 2024
2 parents f2e90ce + 0415cc6 commit 093b46d
Showing 2 changed files with 5 additions and 63 deletions.
62 changes: 4 additions & 58 deletions experimental/torch_xla2/test/test_ops.py
@@ -15,11 +15,11 @@
"_segment_reduce",
"_upsample_bilinear2d_aa",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"block_diag",
"byte",
"cat",
"cauchy",
"cdist",
"ceil",
"cholesky",
"cholesky_inverse",
"cholesky_solve",
@@ -29,17 +29,12 @@
"diagflat",
"diagonal_copy",
"diagonal_scatter",
"diff",
"digamma",
"erfc",
"erfinv",
"expand",
"exponential",
"floor_divide",
"gcd",
"geometric",
"geqrf",
"grid_sampler_2d",
"histogram", # hard op: AssertionError: Tensor-likes are not close!
"histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got <class 'list'> at position 1.
"igammac",
@@ -67,7 +62,6 @@
"linalg.matrix_norm",
"linalg.matrix_power",
"linalg.matrix_rank",
"linalg.multi_dot",
"linalg.norm",
"linalg.pinv",
"linalg.solve",
@@ -77,15 +71,13 @@
"linalg.svdvals",
"linalg.tensorinv",
"linalg.tensorsolve",
"linalg.vander",
"linalg.vector_norm",
"linspace",
"log_normal",
"logspace",
"lu",
"lu_solve",
"lu_unpack",
"masked.argmin",
"masked.median",
"masked_scatter",
"masked_select",
@@ -98,7 +90,6 @@
"new_empty",
"new_empty_strided",
"nextafter",
"nn.functional.adaptive_avg_pool1d",
"nn.functional.adaptive_avg_pool3d",
"nn.functional.adaptive_max_pool1d",
"nn.functional.adaptive_max_pool2d",
@@ -107,17 +98,14 @@
"nn.functional.avg_pool1d",
"nn.functional.avg_pool2d",
"nn.functional.avg_pool3d",
"nn.functional.batch_norm",
"nn.functional.bilinear",
"nn.functional.binary_cross_entropy",
"nn.functional.conv2d",
"nn.functional.conv3d",
"nn.functional.conv_transpose1d",
"nn.functional.conv_transpose2d",
"nn.functional.conv_transpose3d",
"nn.functional.cosine_embedding_loss",
"nn.functional.cosine_similarity",
"nn.functional.cross_entropy",
"nn.functional.ctc_loss",
"nn.functional.dropout2d",
"nn.functional.dropout3d",
@@ -128,9 +116,7 @@
"nn.functional.fractional_max_pool3d",
"nn.functional.group_norm",
"nn.functional.hinge_embedding_loss",
"nn.functional.instance_norm",
"nn.functional.interpolate",
"nn.functional.layer_norm",
"nn.functional.margin_ranking_loss",
"nn.functional.max_pool1d",
"nn.functional.max_pool2d",
@@ -141,17 +127,12 @@
"nn.functional.multi_head_attention_forward",
"nn.functional.multi_margin_loss",
"nn.functional.multilabel_margin_loss",
"nn.functional.multilabel_soft_margin_loss",
"nn.functional.nll_loss",
"nn.functional.normalize",
"nn.functional.one_hot",
"nn.functional.pad",
"nn.functional.pairwise_distance",
"nn.functional.pixel_shuffle",
"nn.functional.pixel_unshuffle",
"nn.functional.poisson_nll_loss",
"nn.functional.rrelu",
"nn.functional.softmin",
"nn.functional.triplet_margin_loss",
"nn.functional.triplet_margin_with_distance_loss",
"nn.functional.unfold",
"nn.functional.upsample_nearest",
"nonzero",
@@ -183,54 +164,19 @@
"sub",
"svd",
"svd_lowrank",
"take_along_dim",
"to_sparse", # We are not supporting sparse tensors yet.
"triu",
"unbind",
"unfold_copy",
"unfold",
"unique_consecutive",
"unique",
"unravel_index",
"trunc",
"var_mean",
"zero_",
"argwhere",
"cumulative_trapezoid",
"expand_as",
"nanmean",
"bmm",
"broadcast_shapes",
"cartesian_prod",
"cdouble",
"ceil",
"chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
"nn.functional.smooth_l1_loss",
"nn.functional.soft_margin_loss",
"nn.functional.softplus",
"nn.functional.softshrink",
"nn.functional.softsign",
"nn.functional.tanhshrink",
"nn.functional.threshold",
"nn.functional.triplet_margin_loss",
"nn.functional.triplet_margin_with_distance_loss",
"nn.functional.upsample_bilinear",
"outer",
"permute",
"positive",
"rad2deg",
"randint",
"ravel",
"reciprocal",
"remainder",
"repeat",
"true_divide",
"trunc",
"unflatten",
"unsafe_chunk",
"unsafe_split",
"unsqueeze",
"view_as_complex",
"view_as",
}

# These inputs are themselves views
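For context, the entries above are op names in a skip set, and a commit like this one, which deletes names from the set, presumably re-enables the corresponding OpInfo-driven tests. Below is a minimal sketch, under that assumption, of how such a skip set is commonly consumed; the names skiplist and ops_to_test and the chosen entries are illustrative, not the actual contents of test_ops.py.

# Sketch only: assumed pattern for how a skip set gates OpInfo-driven tests.
from torch.testing._internal.common_methods_invocations import op_db

# A few entries still present after this commit, used here for illustration.
skiplist = {"cholesky", "lu", "to_sparse"}

# Ops whose OpInfo name appears in the skip set are excluded from test
# generation, so removing a name from the set turns its tests back on.
ops_to_test = [op for op in op_db if op.name not in skiplist]
print(f"testing {len(ops_to_test)} of {len(op_db)} OpInfos")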
6 changes: 1 addition & 5 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -2399,13 +2399,9 @@ def _aten_trunc(a):
  return jnp.trunc(a)


@op(torch.ops.aten.unbind)
@op(torch.ops.aten.unbind_copy)
def _aten_unbind(a, dim=0):
-  return tuple(
-      _aten_squeeze_dim(jax.lax.index_in_dim(a, i, axis=dim), dim)
-      for i in range(a.shape[dim])
-  )
+  return [jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]


# NOTE: skip aten.upsample_nearest2d and aten.upsample_bilinear2d
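For reference, here is a minimal standalone sketch of what the new _aten_unbind body returns; it is illustrative only and not part of the commit. jax.lax.index_in_dim with keepdims=False drops the indexed axis, so each element of the returned list matches a slice produced by torch.unbind along dim.

import jax.numpy as jnp
from jax import lax

a = jnp.arange(6).reshape(2, 3)
dim = 0

# One slice per index along `dim`; keepdims=False removes the unbound axis,
# so each result has shape (3,), mirroring torch.unbind(a, dim).
unbound = [lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim])]
print(unbound)  # [Array([0, 1, 2], dtype=int32), Array([3, 4, 5], dtype=int32)]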
