Implementation of DLRM: Embedding Operations #4227
Sir-NoChill asked this question in Q&A
Hello community!
I am currently trying to reimplement Meta's DLRM algorithm, specifically the architecture discussed in this paper, for profiling and performance research. I am having some trouble writing a Flax implementation of the sparse-vector embedding code:
In Meta's implementation, they initialize a `torch.nn.EmbeddingBag` as follows (see line in the original code):
But they subsequently use it like this (refer to this line):
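For reference, the forward pass of `torch.nn.EmbeddingBag` in sum mode can be reproduced in plain JAX by gathering one row per index and then summing the rows of each bag. This is a minimal sketch that assumes `mode="sum"`, PyTorch's default `include_last_offset=False`, and no `per_sample_weights` or `padding_idx`; the function name `embedding_bag_sum` and the example data are hypothetical.

```python
import jax
import jax.numpy as jnp


def embedding_bag_sum(table, indices, offsets):
    """Hypothetical helper mimicking torch.nn.EmbeddingBag(mode="sum").

    table:   (num_embeddings, dim) embedding matrix
    indices: (total_lookups,) flat int array of row ids for all bags
    offsets: (num_bags,) start position of each bag inside `indices`
    """
    num_bags = offsets.shape[0]  # static shape, so this also works under jax.jit
    # Gather one embedding row per lookup: (total_lookups, dim).
    rows = jnp.take(table, indices, axis=0)
    # Map each lookup position to its bag id.
    # For offsets [0, 2, 4] and 5 lookups this yields [0, 0, 1, 1, 2].
    positions = jnp.arange(indices.shape[0])
    bag_ids = jnp.searchsorted(offsets, positions, side="right") - 1
    # Sum the rows of each bag: (num_bags, dim). Empty bags become zero vectors.
    return jax.ops.segment_sum(rows, bag_ids, num_segments=num_bags)


# Usage with made-up data: 6 embeddings of size 2, three bags.
table = jnp.arange(12.0).reshape(6, 2)
indices = jnp.array([1, 2, 4, 5, 0])
offsets = jnp.array([0, 2, 4])  # bags: [1, 2], [4, 5], [0]
out = embedding_bag_sum(table, indices, offsets)
# expected values: [[6, 8], [18, 20], [0, 1]]
```

The `offsets` to bag-id conversion via `searchsorted` is just one way to turn PyTorch's layout into the segment ids that `jax.ops.segment_sum` expects; if your input pipeline can emit bag ids directly, that step can be dropped.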
However, I cannot find a way to duplicate this functionality using the Flax `nnx.Embed` or `linen.Embed` class. I am also relatively new to JAX/Flax, so I apologize in advance for my further questions :)
My current model is as follows (using nnx):