Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement torch.ops.aten.embedding_renorm_ #8091

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@
"nn.functional.dropout3d",
"nn.functional.dropout",
"nn.functional.embedding_bag",
"nn.functional.embedding",
"nn.functional.fractional_max_pool2d",
"nn.functional.fractional_max_pool3d",
"nn.functional.group_norm",
Expand Down
23 changes: 22 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,30 @@ def _aten_bmm(x, y):

@op(torch.ops.aten.embedding)
# embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
def _aten_embedding(a, w, padding_idx=-1):
def _aten_embedding(a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
return jnp.take(a, w, axis=0)

@op(torch.ops.aten.embedding_renorm_)
def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
# Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
unique_indices = jnp.unique(indices)

norm = jnp.linalg.norm(
_aten_embedding(weight, unique_indices),
ord=norm_type,
axis=1,
)

indice_idx = jnp.where(norm > max_norm)

scale = max_norm / (norm[indice_idx] + 1e-7)

indices_to_update = unique_indices[indice_idx]

weight = weight.at[indices_to_update].set(
weight[indices_to_update] * scale[:, None]
)
return weight

#- func: _embedding_bag_forward_only(
# Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
Expand Down
Loading