
WIP: functional API #521

Closed
wants to merge 2 commits into from

Conversation

jonas-eschle
Contributor

@jonas-eschle jonas-eschle commented Mar 6, 2024

If we swap the order of data and params, the params associated with the model are treated as defaults and can be overridden by giving them explicitly.

This (to my understanding) makes the function callable with parameters as args instead of updating the parameters in the model.

There should not be any legacy issue, as I highly doubt anyone passes wrong parameter values and relies on the function to overwrite them with the defaults?

If that's fine, I'll gladly add a quick test, changelog etc.

What do you think @redeboer ?

@redeboer redeboer marked this pull request as draft March 6, 2024 09:10
@redeboer redeboer added the ⚙️ Enhancement Improvements and optimizations of existing features label Mar 6, 2024
@redeboer
Member

redeboer commented Mar 6, 2024

Hi @jonas-eschle, thanks a lot for diving through the codebase and thinking along!

First some background on the function interfaces.
On the most general level, a Function is just a callable that gives some output for some input. Estimator is an example of that: it takes a dict of parameters and returns a float. Functions that we plot or fit usually take a DataSample and return some array (usually 1D, real-valued intensities). ParametrizedFunction is a specific Function interface that keeps track of which arguments are to be tweaked by a fitting algorithm.1 (I now realize some type parameters were missing, see #522.)
The reason for using dicts to pass around parameters and data input is mainly the large number of function arguments (parameters and data columns) in amplitude models. In addition, when functions are created from SymPy expressions, parameter and data arguments can have rather long names that are 'dummified' to valid Python argument names. Dictionary inputs remove the need for the caller to think about those argument positions.

I think what you want to change has more to do with the argument order of the functions that are generated with SymPy. Essentially #488, but swapped, see here. But this is more about the implementation, not the API. For that we would have to think about 1.

Note

CI currently fails because of ComPWA/actions#61. The current constraint files were generated with pip-tools, whereas uv pip install somehow requires a narrower set of constraints (specifically a lower version of python-constraints, see ComPWA/qrules#256). This should be fixed tomorrow in #523, when the constraints are automatically updated, this time with uv pip compile (ComPWA/update-pip-constraints#22).

Footnotes

  1. Their update_parameters() method is for that. But at some stage this has to be redesigned or abandoned altogether, as it seemingly does not work well with autodiff. This was also the motivation for ENH: set data keys as first positional arguments #488.

@jonas-eschle
Contributor Author

jonas-eschle commented Mar 6, 2024

Thanks for the explanations @redeboer, those are useful insights!

AFAIU, #488 introduces some ordering? I may not quite understand the details, so maybe that's why I am missing the point (or is this heavily WIP anyway?)

Btw,

The reason for using dicts to pass around parameters and data

fully agree, and I would suggest doing that anyway.

But I think my point is different: the order of the arguments doesn't matter (as it's a mapping), but the code has a different meaning:

`params = {**p1, **p2}`

gives precedence to p2.
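
A quick demonstration of that rule (p1/p2 are hypothetical names, just to illustrate Python's dict-unpacking semantics):

```python
# p1/p2 stand in for "model defaults" vs. "values given explicitly"
p1 = {"mu": 0.0, "sigma": 0.5}  # e.g. the model's default parameters
p2 = {"mu": 3.5}                # e.g. values the caller passes explicitly
params = {**p1, **p2}
# keys from the later mapping win, so p2 takes precedence
assert params == {"mu": 3.5, "sigma": 0.5}
```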

In the code, this means that it gives precedence to the default args. So consider the following:

Assume the function takes x (data) and mu, sigma (parameters), and the model has defaults for mu and sigma.

We can call it with just `data = {'x': array}`. That's fine: it uses the defaults for mu & sigma.

But we cannot pass the parameters explicitly, i.e. `data = {'x': array, 'mu': float}` will not use mu, because it will be overridden by the default.

Now it basically says: "if the user gives a parameter value, ignore it and use the default". If we swap as suggested, it becomes: "use the default values if not given, but if the user provides a parameter explicitly, use that instead".

So basically nothing changes, except that a parameter can now be given explicitly. Actually, the current order of deliberately (and silently!) overwriting parameters that are given explicitly to the function is probably not intentional? It looks a bit to me as if this should always have been the other way around?
Or is there maybe a disadvantage that I don't see?
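
In code, the proposed swap could look something like this minimal sketch (the `ToyParametrizedFunction` class and `gaussian` function are hypothetical stand-ins, not the actual tensorwaves implementation):

```python
import numpy as np

class ToyParametrizedFunction:
    """Minimal stand-in for a parametrized function with default parameters."""

    def __init__(self, function, parameters):
        self.parameters = dict(parameters)  # stored "default" parameter values
        self.function = function

    def __call__(self, data):
        # Current behavior corresponds to {**data, **self.parameters}:
        # the defaults silently win. The proposed order lets the caller win:
        kwargs = {**self.parameters, **data}
        return self.function(**kwargs)

def gaussian(x, mu, sigma):
    return np.exp(-(((x - mu) / sigma) ** 2) / 2) / (sigma * np.sqrt(2 * np.pi))

func = ToyParametrizedFunction(gaussian, parameters={"mu": 0.0, "sigma": 0.5})
x_values = np.linspace(-2, 2, 5)
y_default = func({"x": x_values})               # falls back to mu=0.0
y_override = func({"x": x_values, "mu": -1.5})  # the caller's mu wins
```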

@redeboer
Member

redeboer commented Mar 6, 2024

AFAIU, #488 introduces some ordering? I may not quite understand the details, so maybe that's why I am missing the point (or is this heavily WIP anyway?)

Indeed, it's an internal thing related to how sympy lambdifies symbols to argument names. Because non-Python symbol names are dummified, you can't really call the (internal) lambdified function with keyword arguments, so #488 separates the parameter arguments from the data arguments. It's indeed WIP (see ComPWA/compwa.github.io#195), but it doesn't affect the interfaces.

But I think my point is different

Ahh sorry, I misunderstood the intention of the PR then :) I thought it had to do with the internal lambdified functions, so #488. You want to provide the option to use the parametrized function with different parameter values without using update_parameters(). Okay, yeah, that would work, but it's a bit hacky and goes against the way the interfaces are designed now (the ParametrizedFunction call expects a DataSample that has arrays as values).

If we go this way, the interface descriptions have to be updated (always good to do anyway :)). It means that a ParametrizedFunction carries around some default values that can optionally be updated with update_parameters(), but that you can also temporarily overwrite them by passing them through the call as dictionary keys.

Btw, is there a specific use case you have in mind?

@jonas-eschle
Contributor Author

Yes, exactly, that's the idea, and I think, except for adding to the docs that you can also use it directly, there is nothing to change; it should be 100% backwards compatible and rather forward compatible (at least if I think about the best API design).

Btw, is there a specific use case you have in mind?

The main consequence: it becomes jittable (right? Or am I missing something here?)! Also, the parametrized function automatically turns from a very specific object with a very specific interface (such as the parameter update) into a general function that takes the variables as inputs. But even without the jitting, for example for minimization, it's now a completely functional API as well, basically for free. So JIT and minimization are the main motivation. (I thought grads would also work, which would have been even better, but yep.)

So there is a difference between calling the function with specific parameters and changing the function's default values. While autodiff doesn't quite work with it (not sure why?), this is also the way to go for that.

it's a bit hacky and going against the way the interfaces are designed now (the ParametrizedFunction call expects a DataSample that has arrays as values).

Is it? To me, it looks perfectly elegant. Especially since internally, the ParametrizedFunction does the exact same thing but overrides the parameters if given: it currently doesn't allow this but forces its state. I think it would solve, in a very elegant way, one of the largest problems with the API, namely that it isn't functional (but here I may be missing details).

Since you're also working on jitting in other PRs I've seen, maybe I am really just missing details. But then, the functional API that this offers for free seems easy to enable, without a downside I think?

@redeboer
Member

redeboer commented Mar 6, 2024

there is nothing to change; it should be 100% backwards compatible and rather forward compatible (at least if I think about the best API design).

Yes, that's true for the implementation that this PR affects: it extends the behaviour of ParametrizedBackendFunction, which is an implementation of the ParametrizedFunction interface. But it does not guarantee that any other implementation of ParametrizedFunction has the same behaviour. In fact, its call signature suggests that you should feed it data samples, not parameters. (I'm not saying that this is a good design/API – the struggle with autodiff suggests otherwise – but it is what the interfaces currently guarantee.)

The main consequence: it becomes jittable

I'm not too familiar with the internals of jitting, but could be. I wonder whether (and why) a ParametrizedBackendFunction as it is now is jittable (I think not, because of the dictionaries the class is carrying around). If not, why would it become jittable if parameters can be passed to the call method?
Perhaps this also touches on design. As of now, the interfaces do not guarantee or suggest that they are jittable, as jitting etc. is considered to be done by the function itself internally (I'm not even sure what that would mean – guarantee that you can call jax.jit on it? 😅). In the case of ParametrizedBackendFunction, you can access the (potentially jitted) function through its function attribute (this is intentionally not part of the ParametrizedFunction interface).

one of the largest problems with the API, namely that it isn't functional

Good point, actually the perfect summary of the problems that the ParametrizedFunction has 😆 It came from the original ComPWA project and the more traditional thinking in terms of an optimizer that updates parameters step by step.

@jonas-eschle
Contributor Author

We can also discuss this tomorrow, it's an interesting issue ;)

If not, why would it become jittable if parameters can be passed to the call method?

It passes through a "symbolic parameter", much like sympy, but care must be taken that it's not converted somewhere. Also, it would not reflect any updated state, leading to different behavior when run with JIT vs. without (i.e. if a parameter is set, the Python object won't be updated if jitted, but it will be if not jitted).

I think it's jittable because it's simply a function now? Hm, ParametrizedBackendFunction, okay! So that's the "executable". I've been a bit confused about what to call and what not to, I'll try that! ParametrizedBackendFunction.function is maybe what I was "looking for".

@redeboer
Member

redeboer commented Mar 6, 2024

We can also discuss this tomorrow, it's an interesting issue ;)

Indeed!

I think it's jittable because it's simply a function now? Hm, ParametrizedBackendFunction, okay! So that's the "executable". I've been a bit confused about what to call and what not to, I'll try that! ParametrizedBackendFunction.function is maybe what I was "looking for".

I don't know how jit works exactly (it's probably also package-specific), but if I were jit, I would have a hard time figuring out what the effect of those internal parameters is, even if the unpacking order is swapped 😅 If that is true, the fact that a ParametrizedFunction (yes, the interface, so anything that implements it) is not functional makes it non-jittable by design.

I've been a bit confused about what to call and what not to, I'll try that! ParametrizedBackendFunction.function is maybe what I was "looking for".

Hah, actually this touches on a larger issue :D The use of tensorwaves has shifted towards streamlining the creation and use of large expressions/functions and less towards fitting, even if the interfaces had the latter in mind. Perhaps that is also what this PR is about, that the interfaces need to be rethought towards that. Also interesting!

@jonas-eschle
Contributor Author

jonas-eschle commented Mar 6, 2024

what the effect of those internal parameters is, even if the unpacking order is swapped

Actually, here only Python is happening: yes, these are "symbolic parameters", but only the ones that survive the dict merge survive the jit, so to say.

Hah, actually this touches on a larger issue :D

Sounds good, that is well aligned, and I am really happy to see the extensive work on the conversion and careful backend handling; it's quite a tricky issue (so happy to make use of it!)

P.S.: I'll keep it open for now in case somebody has an urgent comment, but I'll probably close it if things are meant to be handled differently anyway, and we can discuss in other PRs/chat.

@jonas-eschle
Contributor Author

jonas-eschle commented Mar 8, 2024

@redeboer feel free to close, whatever you think fits best; as discussed, it's not crucial and I'll use other ways to access the function.
(If we keep it open with the intent to merge, I can gladly add some docs to explain the behavior.)

@redeboer
Member

redeboer commented Mar 11, 2024

Thanks again for thinking along! :)
As discussed, this issue/PR touches on larger design issues, so I wouldn't implement this change given what the interfaces currently are. It would make the parameters in a ParametrizedBackendFunction behave like 'default' parameters:

```python
import numpy as np
import sympy as sp

from tensorwaves.function.sympy import create_parametrized_function

x, μ, σ = sp.symbols("x mu sigma")
gaussian_expr = sp.exp(-(((x - μ) / σ) ** 2) / 2) / (σ * sp.sqrt(2 * sp.pi))

gaussian_func = create_parametrized_function(
    gaussian_expr,
    parameters={μ: 0, σ: 0.5},
    backend="numpy",
)
gaussian_func.update_parameters({"mu": 3.5})  # means updating "default" parameters?
y_values = gaussian_func(
    data={  # type hints and keyword suggest a data sample as input
        "x": np.linspace(-2.0, +2.0, num=100),
        "mu": -1.5,  # now this parameter is used
    }
)
```

Instead, to achieve the desired behavior, it is probably better to use `create_function()`. The resulting function does not have `update_parameters()`:

```python
from tensorwaves.function.sympy import create_function

gaussian_func = create_function(gaussian_expr, backend="numpy")
y_values = gaussian_func({
    "x": np.linspace(-2.0, +2.0, num=100),
    "mu": -1.2,
    "sigma": 0.5,
})
```

Still, the PR touched on many points that are probably better discussed in #525.

@redeboer redeboer closed this Mar 11, 2024
@redeboer redeboer deleted the je_feat_functionalcall branch March 11, 2024 09:06