Implement torch.ops.aten.embedding_renorm_ (#8091)
guyao committed Sep 30, 2024
1 parent ecc0f5a commit 5872b20
Showing 2 changed files with 22 additions and 2 deletions.
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
@@ -96,7 +96,6 @@
     "nn.functional.dropout3d",
     "nn.functional.dropout",
     "nn.functional.embedding_bag",
-    "nn.functional.embedding",
     "nn.functional.fractional_max_pool2d",
     "nn.functional.fractional_max_pool3d",
     "nn.functional.group_norm",
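(Note: with torch.ops.aten.embedding_renorm_ implemented below, the nn.functional.embedding OpInfo test no longer needs to be skipped, so its entry is removed from this skip list.)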
23 changes: 22 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -367,9 +367,30 @@ def _aten_bmm(x, y):

 @op(torch.ops.aten.embedding)
 # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False)
-def _aten_embedding(a, w, padding_idx=-1):
+def _aten_embedding(a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
   return jnp.take(a, w, axis=0)

+@op(torch.ops.aten.embedding_renorm_)
+def _aten_embedding_renorm_(weight, indices, max_norm, norm_type):
+  # Adapted from https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp
+  unique_indices = jnp.unique(indices)
+
+  norm = jnp.linalg.norm(
+      _aten_embedding(weight, unique_indices),
+      ord=norm_type,
+      axis=1,
+  )
+
+  indice_idx = jnp.where(norm > max_norm)
+
+  scale = max_norm / (norm[indice_idx] + 1e-7)
+
+  indices_to_update = unique_indices[indice_idx]
+
+  weight = weight.at[indices_to_update].set(
+      weight[indices_to_update] * scale[:, None]
+  )
+  return weight

 #- func: _embedding_bag_forward_only(
 #  Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False,
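For reference, a minimal standalone sketch of the renorm semantics the new op implements (an illustration in plain jax.numpy, not part of the commit; embedding_renorm, over, and rows are hypothetical names mirroring the _aten_embedding_renorm_ body above, minus the @op registration): rows of weight that indices references and whose norm_type-norm exceeds max_norm are scaled down to approximately max_norm; all other rows are left untouched. This matches the in-place renormalization PyTorch performs when torch.nn.functional.embedding is called with max_norm set.

import jax.numpy as jnp

def embedding_renorm(weight, indices, max_norm, norm_type):
  # Only the rows actually referenced by `indices` are considered.
  unique_indices = jnp.unique(indices)
  norm = jnp.linalg.norm(jnp.take(weight, unique_indices, axis=0),
                         ord=norm_type, axis=1)
  over = jnp.where(norm > max_norm)        # referenced rows whose norm exceeds max_norm
  scale = max_norm / (norm[over] + 1e-7)   # epsilon guards against division by zero
  rows = unique_indices[over]
  return weight.at[rows].set(weight[rows] * scale[:, None])

weight = jnp.array([[3.0, 4.0],    # L2 norm 5.0  -> rescaled to ~1.0
                    [0.6, 0.8],    # L2 norm 1.0  -> not over max_norm, unchanged
                    [6.0, 8.0]])   # L2 norm 10.0 -> never indexed, untouched
renormed = embedding_renorm(weight, jnp.array([0, 1, 0]),
                            max_norm=1.0, norm_type=2.0)
# jnp.linalg.norm(renormed, axis=1) is approximately [1.0, 1.0, 10.0]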
