Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Bayesian PINN solver for ODE #692

Merged
merged 40 commits into from
Aug 17, 2023

Conversation

AstitvaAggarwal
Copy link
Contributor

fixes #682

src/turing_MCMC.jl Outdated Show resolved Hide resolved
src/turing_MCMC.jl Outdated Show resolved Hide resolved
src/turing_MCMC.jl Outdated Show resolved Hide resolved
@AstitvaAggarwal
Copy link
Contributor Author

AstitvaAggarwal commented Jun 29, 2023

only changed Project.toml(added some packages), NeuralPDE.jl, turing_MCMC.jl and added 2 new files(BPINN_Tests.jl, advancedHMC_MCMC.jl), the rest is the result of JuliaFormatter.format().

@Vaibhavdixit02
Copy link
Member

Why are there formatting changes in this PR, I think it might have been by mistake? Try to see if you can remove those since it makes hard to review relevant files

@AstitvaAggarwal
Copy link
Contributor Author

yeah in commit-978e115 I formatted the whole NeuralPDE directory, ive fixed it now. Sorry for the inconvenience.

@AstitvaAggarwal
Copy link
Contributor Author

AstitvaAggarwal commented Jul 11, 2023

7mins for finite diff derivative scheme and ~20mins for autodiff, 700 samples for NUTS with multinomial sampling. beyond a certain number of samples the Hamiltonian energy decrease stagnates. Problem specific set of Metrics,integrators,adaptors etc. increases performance, speed(HMC faster by around 5x). I was thinking something like Adaptive weighting of BPINNs might help as well due to the nature of optimization.

test/BPINN_Tests.jl Outdated Show resolved Hide resolved
test/BPINN_Tests.jl Outdated Show resolved Hide resolved
test/BPINN_Tests.jl Outdated Show resolved Hide resolved
# priors: pdf for W,b + pdf for ODE params
# lotka specific kwargs here

function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to support Float32 for GPU use

.github/workflows/CI.yml Outdated Show resolved Hide resolved
@@ -13,14 +13,17 @@ jobs:
fail-fast: false
matrix:
group:
#fixes 682
-ODEBPINN
Copy link
Member

@xtalax xtalax Aug 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

- ODEBPINN

needs a space. One check runs for each group, not for each testset which is why there are only 5 (should be 6)

p = Tar.prob.p
dt = Tar.physdt
if isempty(Tar.dataset[end])
t = collect(Float64, Tar.prob.tspan[1]:dt:Tar.prob.tspan[2])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace all Float64 with eltype(\theta) in case we're running on GPU


# dataset would be (x̂,t)
# priors: pdf for W,b + pdf for ODE params
function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xtalax does this suit the API in NeuralPDE or should we change this?

Copy link
Member

@Vaibhavdixit02 Vaibhavdixit02 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just curious about the API here being suitable which can also be addressed from another PR as per @xtalax preference

Comment on lines 16 to 18
#fixes 682
- ODEBPINN

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#fixes 682
- ODEBPINN
- ODEBPINN

@@ -0,0 +1,459 @@
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this docstring needs to move to be right on top of the function, otherwise it won't work

chain::C
st::S
dataset::Vector{Vector{Float64}}
priors::Vector{Distribution}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be unstable: does it not effect performance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ive been doing @time and @Btime on the function on the merged struct version and the version you put this review (but uses prob::ODEProblem) the difference dosent change much. In fact on many runs the unstable version performs better.

t[i])
for i in 1:length(out[1, :])]
end
physsol = hcat(physsol...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
physsol = hcat(physsol...)
physsol = reduce(hcat,physsol)

@ChrisRackauckas
Copy link
Member

Implementation-wise, I think it generally looks pretty good now. I left a few comments but those are pretty minor.

As for API, I do think that might need to change a little. I think this should be setup as just a solve(prob::ODEProblem, alg::BNNODE(...)) with the details for the fitting process in the BNNODE. Now this would be a bit of a weird ODE solver since the solution cannot be easily represented as an ODESolution, but what we can do is add a .original field to the ODESolution (which exists in most other solution types already anyways so it's a reasonable extension) which then holds the fh_mcmc_chain, fhsamples, fhstats (which we can put in a BPINNStatistics struct or something), and then the "standard" solution computation can be done in advance to construct the sol.u. That would make all of this stuff (https://github.com/SciML/NeuralPDE.jl/pull/692/files#diff-f426df4f589f3ec64dfa582ee30371490608d20f836b399cbfebcb443e832c98R130-R147) just isapprox(sol.u , ...) which makes sense because most users will first want to know "what's the ODE solution"? Something that would be nice to the user would be to represent this as a MonteCarloMeasurements or Measurements type so that plotting and such automatically works with error bounds. As done in NNODE, the t would just be chosen to match the saveat.

Then ahmc_bayesian_pinn_ode can be documented as a lower level API for those who really want to dive in, but solve(prob::ODEProblem, alg::BNNODE(...)) would then work very naturally "just like any other ODE solver", with the extra information always available in sol.original but with the uncertainty information represented in the u. That API would be much nicer than the current post-solution analysis required in the tests.

That all said, creating that higher level API is simply calls to ahmc_bayesian_pinn_ode under the hood and then the post analysis that's exactly done in the tests right now, so I don't think doing this high level API will require any new code. It's just restructuring it to be a bit simpler.

@AstitvaAggarwal
Copy link
Contributor Author

AstitvaAggarwal commented Aug 16, 2023

Thank you for the reviews, I'll create a seperate Issue, PR for the High level APIs for both the ODE, PDE solver together. I'll move onto the PDE solver first for now, unless you all would recommend finishing with the APIs first.

@ChrisRackauckas
Copy link
Member

Finishing the higher level ODE interface would be top priority. This functionality effectively won't exist until that piece is done.

@ChrisRackauckas ChrisRackauckas merged commit a8ef303 into SciML:master Aug 17, 2023
14 of 17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bayesian Inference for NNODE
4 participants