Is there a 2D "masked convolution" in JAX/Flax? #979

Answered by avital
avital asked this question in Q&A
@levskaya points out that:

Masked convolutions are a fair bit of work to implement, and I don't think anyone has done this yet in any JAX setting that I'm aware of.
Curious: what kind of masking do you want? "Masked conv" is a bit of an overloaded term. Do you mean causal constraints for generation, or general sparsity? Or something more like what are called partial convs?

@j-towns actually implemented PixelCNN++-style causal conv layers (ConvDown, ConvDownRight), e.g. https://github.com/google/flax/blob/master/examples/pixelcnn/pixelcnn.py#L231, but they're pretty specific to that model rather than a fully parametrized, general masked conv.
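
For the causal case specifically, one common workaround is to multiply the kernel by a fixed binary mask before calling an ordinary convolution. Below is a minimal sketch of that idea using `jax.lax.conv_general_dilated`. It is not the approach used in the PixelCNN++ example linked above (which shifts and pads inputs instead), and the helper names `causal_mask` and `masked_conv2d` are hypothetical, made up for illustration. It assumes NHWC inputs, HWIO kernels, odd kernel sizes, and masks only spatially (no per-channel ordering).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def causal_mask(kh, kw, include_center):
    """Binary (kh, kw) mask keeping only taps before the center in raster order.

    Hypothetical helper. include_center=False gives what the PixelCNN paper
    calls a type "A" mask; include_center=True gives a type "B" mask.
    """
    mask = np.zeros((kh, kw), dtype=np.float32)
    ch, cw = kh // 2, kw // 2
    mask[:ch, :] = 1.0    # every row above the center row
    mask[ch, :cw] = 1.0   # center row, strictly left of the center
    if include_center:
        mask[ch, cw] = 1.0
    return jnp.asarray(mask)


def masked_conv2d(x, kernel, mask):
    """Hypothetical helper: convolve with the kernel zeroed where mask == 0.

    x: (N, H, W, C_in), kernel: (kh, kw, C_in, C_out), mask: (kh, kw).
    """
    # Broadcast the spatial mask over the input/output channel axes.
    masked_kernel = kernel * mask[:, :, None, None]
    return lax.conv_general_dilated(
        x,
        masked_kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jax.random.normal(jax.random.PRNGKey(0), (1, 5, 5, 2))
kernel = jax.random.normal(jax.random.PRNGKey(1), (3, 3, 2, 4))
y = masked_conv2d(x, kernel, causal_mask(3, 3, include_center=False))
```

Because the mask is applied to the kernel on every call, gradients for the masked taps are zero, so those weights never influence the output even though they still exist as parameters. A general sparsity mask would work the same way with a different `mask` array; partial convs additionally renormalize by the number of valid inputs, which this sketch does not do.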

Answer selected by marcvanzee