Does this approach save forward passes? #18

Open
gdalle opened this issue Apr 22, 2024 · 1 comment

gdalle commented Apr 22, 2024

Hi @martinResearch and thanks for creating this package, which I stumbled upon while reading jax-ml/jax#1032. This issue is about understanding it a little better.

From what I can tell, known sparsity patterns are most useful in conjunction with coloring algorithms, because they reduce the number of forward (or reverse) passes needed to compute a Jacobian. Typically, the Jacobian of a function $f : R^n \to R^m$ would require $n$ forward passes, one for each of the Jacobian-vector products associated with the basis vectors of $R^n$. Grouping basis vectors together is what the coloring step is all about, and it can reduce the number of forward passes from $O(n)$ to $O(1)$ in the best cases. See https://epubs.siam.org/doi/10.1137/S0036144504444711 for more details, or Example 5.7 from https://tomopt.com/docs/TOMLAB_MAD.pdf#page26.
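
To make the pass counting concrete, here is a minimal MATLAB sketch of the coloring idea (purely illustrative, not code from this package): the toy function below has a tridiagonal Jacobian, so 3 grouped passes recover all $n$ columns, with a forward difference standing in for a JVP / forward-mode pass.

```matlab
% Illustrative sketch of compressed Jacobian estimation by coloring.
% The Jacobian of f is tridiagonal, so 3 grouped passes recover all n columns.
n = 6;
f = @(x) [x(1)^2 + x(2); x(1:end-2) + x(2:end-1).^2 + x(3:end); x(end-1) + x(end)^2];
jvp = @(x, v) (f(x + 1e-7*v) - f(x)) / 1e-7;   % forward difference stands in for a JVP

x = rand(n, 1);
J = zeros(n, n);
for c = 1:3                                    % 3 colors suffice for a tridiagonal pattern
    cols = c:3:n;                              % columns of the same color have disjoint row supports
    seed = zeros(n, 1); seed(cols) = 1;        % sum of the basis vectors of this color
    d = jvp(x, seed);                          % one pass instead of numel(cols) passes
    for j = cols
        rows = max(1, j-1):min(n, j+1);        % structural nonzeros of column j
        J(rows, j) = d(rows);                  % decompression
    end
end
```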

As you stated in jax-ml/jax#1032 (comment), your library does not rely on this paradigm. When you talk of a "single forward pass using matrix-matrix products at each step of the forward computation", is it correct that you still end up computing the JVP with every single basis vector? In other words, while the runtime may be low in practice thanks to vectorization and efficient sparse operations, the theoretical complexity remains $O(n)$?

Thanks in advance for your response

ping @adrhill

martinResearch (Owner) commented

Hi @gdalle,
"is it correct that you still end up computing the JVP with every single basis vector": in some sense, yes. I start with the identity matrix for the derivatives of the input vector (

x.derivatives = speye(numel(values));

) and then multiply that derivatives matrix by the Jacobian of each operation in the chain using the forward chain rule. Each column of the initial identity matrix can be interpreted as a basis vector. Even if I used dense matrices to represent the derivatives, the complexity would still be O(n); this differs slightly from an implementation of forward AD that calls the end-to-end function n times, because here the code is executed only once.
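
To illustrate what I mean (a toy sketch of the propagation scheme, not the actual code of this package), for a chain of elementwise operations the derivative matrix starts as the sparse identity and is left-multiplied by each local Jacobian in a single forward pass:

```matlab
% Toy sketch of forward-mode propagation with a full seed matrix.
n = 5;
x = rand(n, 1);
D = speye(n);                            % d(input)/d(input): one column per basis vector

y = sin(x);                              % step 1: local Jacobian is diag(cos(x))
D = spdiags(cos(x), 0, n, n) * D;

z = y.^2;                                % step 2: local Jacobian is diag(2*y)
D = spdiags(2*y, 0, n, n) * D;

full(D)                                  % D is now the Jacobian of z w.r.t. x
```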

"the theoretical complexity remains O(n)?", It depends on how you define "theoretical" and were we draw the line between theoretical and practice.

If the function you want to differentiate has a Jacobian whose sparsity is s, I would expect the complexity of the coloring approach to be O((1-s)*n), with (1-s)*n an approximation of the number of groups, and therefore of passes, needed. Is that correct?

The complexity of the approach implemented here would also theoretically be O((1-s)*n) if we assume that all the intermediate derivative matrices in the forward chain rule also have sparsity s or greater. I believe that in general this is a reasonable assumption, because it is rare for the intermediate derivatives to be denser than the final derivatives: that would require derivative values that cancel each other out at some step in order to increase the sparsity in the following steps of the chain. If we can prove, for a particular class of problems, that the sparsity of the intermediate derivatives is at least s, then the speedup is not only practical but also theoretical for that specific class of problems.
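
As a concrete (purely illustrative) check of that assumption, the snippet below tracks the number of nonzeros of the intermediate derivative matrix along a small chain whose final Jacobian is banded; the intermediate matrices never get denser than the final one:

```matlab
% Illustrative only: fill-in of the intermediate derivative matrices along a chain.
n = 100;
x = rand(n, 1);
D = speye(n);                                              % seed: identity

A = spdiags(repmat([0.25 0.5 0.25], n, 1), -1:1, n, n);    % tridiagonal averaging stencil
y = A * x;
D = A * D;                                                 % intermediate derivative matrix
fprintf('after linear stencil: nnz(D) = %d out of %d\n', nnz(D), n^2);

z = tanh(y);
D = spdiags(1 - tanh(y).^2, 0, n, n) * D;                  % elementwise op: diagonal local Jacobian
fprintf('after tanh:           nnz(D) = %d out of %d\n', nnz(D), n^2);
```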
