Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Bayesian PINN solver for ODE #692

Merged
merged 40 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d897e3f
begin project
AstitvaAggarwal May 11, 2023
b77488a
Merge branch 'develop' of https://github.com/AstitvaAggarwal/NeuralPD…
AstitvaAggarwal Jun 4, 2023
907c3ed
Started by defining Combined likelihood and priors
AstitvaAggarwal Jun 6, 2023
c8b5dba
model + custom likelihood + adding tests
AstitvaAggarwal Jun 11, 2023
81c47cc
formatted stuff
AstitvaAggarwal Jun 11, 2023
56e7afc
It works but predictions are superbad
AstitvaAggarwal Jun 17, 2023
978e115
formats, added advancedhmc version of BPINN, BPINN test file
AstitvaAggarwal Jun 29, 2023
a793b3d
Tests -BNN, BPINN,changes in var,full/half logprob
AstitvaAggarwal Jun 30, 2023
e817b93
custom logprob in turing_MCMC ,advancedHMC_MCMC-added NN parameter pr…
AstitvaAggarwal Jul 3, 2023
7980448
Trying to fix formatting done by mistake in commit-978e115
AstitvaAggarwal Jul 4, 2023
b760c60
The BPINN ODE solver works well for interpolation, now just needs gen…
AstitvaAggarwal Jul 6, 2023
7e3a3cd
Performs well in Extrapolation,Interpolation.
AstitvaAggarwal Jul 8, 2023
31d1f65
autodiff takes ~30mins sample- test file example
AstitvaAggarwal Jul 10, 2023
8329b8f
Less Numerical errors, faster sampling(autodiff),better EBFMI estimat…
AstitvaAggarwal Jul 11, 2023
15ec776
Inverse parameter estimation, seperate priors for weights and biases
AstitvaAggarwal Jul 15, 2023
c2cce0b
Inverse problem gives incorrect parameter esim
AstitvaAggarwal Jul 17, 2023
dfed5ee
Tests for inverse problem estimation
AstitvaAggarwal Jul 17, 2023
f77d04a
parameter estimation does not work
AstitvaAggarwal Jul 17, 2023
7cb6141
Parameter Estimation works, Custom choice for Prior Distributions
AstitvaAggarwal Jul 22, 2023
e8b1449
fixed a small error,cleaned BPINNtests.jl
AstitvaAggarwal Jul 22, 2023
b4e090d
Now works with Lux.jl chains,parallel sampled chains,good ode paramet…
AstitvaAggarwal Jul 25, 2023
9c07c97
Cleared up majority of Tests, BPINN ODE solver done.
AstitvaAggarwal Aug 1, 2023
2c70494
minor changes
AstitvaAggarwal Aug 1, 2023
9a47e36
Tests run automatically
AstitvaAggarwal Aug 7, 2023
c24ed46
Test Run automatically attempt-1
AstitvaAggarwal Aug 7, 2023
4468534
minor change
AstitvaAggarwal Aug 7, 2023
0bf8152
small change again fr fr
AstitvaAggarwal Aug 7, 2023
b23db49
changes from reviews-1
AstitvaAggarwal Aug 7, 2023
08883bd
Removed extra packages used in Tests
AstitvaAggarwal Aug 7, 2023
a42b6f7
Merge branch 'SciML:master' into develop
AstitvaAggarwal Aug 8, 2023
48dfa66
Test dependancies errors
AstitvaAggarwal Aug 9, 2023
bca2743
uses AHMC convienience structs
AstitvaAggarwal Aug 9, 2023
5357b33
Improved Tests - 1
AstitvaAggarwal Aug 10, 2023
3b35fc3
improved Tests - 2
AstitvaAggarwal Aug 11, 2023
d467250
Fixed the Problem in Test
AstitvaAggarwal Aug 12, 2023
f2e502a
improved Tests
AstitvaAggarwal Aug 12, 2023
e479c47
pls work
AstitvaAggarwal Aug 12, 2023
4df2d63
This shall work!
AstitvaAggarwal Aug 12, 2023
ec24839
Type checking for dataset, Solver usage made clear
AstitvaAggarwal Aug 15, 2023
0428dc5
changes from reviews
AstitvaAggarwal Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "5.7.0"

[deps]
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -37,6 +53,22 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
DiffResults = "1.0"
Distances = "0.10"
DynamicHMC = "2.1, 3"
LabelledArrays = "1.0"
LogDensityProblemsAD = "1"
MacroTools = "0.5"
Missings = "0.4, 1.0"
PDMats = "0.11"
Parameters = "0.12"
Requires = "1.0"
StanSample = "6, 7"
StructArrays = "0.6"
TransformVariables = "0.8"
TransformedLogDensities = "1"
Turing = "0.25"

Adapt = "3"
ArrayInterface = "6, 7"
CUDA = "4"
Expand Down Expand Up @@ -78,5 +110,9 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

ParameterizedFunctions = "65888b18-ceab-5e60-b2b9-181511a3b968"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"

[targets]
test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"]
test = ["Test", "ParameterizedFunctions", "StatsBase","SteadyStateDiffEq","CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"]
16 changes: 16 additions & 0 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@ $(DocStringExtensions.README)
"""
module NeuralPDE

# Fixes #682
using Turing, MacroTools, LinearAlgebra
using Parameters, Distributions, Optim, Requires
using Distances, DocStringExtensions, Random, StanSample
using DynamicHMC, TransformVariables, LogDensityProblemsAD, TransformedLogDensities
STANDARD_PROB_GENERATOR(prob, p) = remake(prob; u0 = eltype(p).(prob.u0), p = p)
function STANDARD_PROB_GENERATOR(prob::EnsembleProblem, p)
EnsembleProblem(remake(prob.prob; u0 = eltype(p).(prob.prob.u0), p = p))
end

using DocStringExtensions
using Reexport, Statistics
@reexport using DiffEqBase
Expand Down Expand Up @@ -38,6 +48,9 @@ abstract type AbstractPINN end

abstract type AbstractTrainingStrategy end

# Fixes #682
include("turing_MCMC.jl")

include("pinn_types.jl")
include("symbolic_utilities.jl")
include("training_strategies.jl")
Expand All @@ -62,4 +75,7 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
MiniMaxAdaptiveLoss,
LogOptions

# Fixes #682
export turing_MCMC

end # module
49 changes: 49 additions & 0 deletions src/turing_MCMC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using Turing

function pinn(prob::DiffEqBase.DEProblem, parameters, dataset, recon)
chain = recon(parameters)

# output PINN predictions
sol
end

myloglikelihood(x, μ) = loglikelihood(MvNormal(μ, 1), x)

# priors: pdf for W,b + pdf for ODE params
function bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain, priors, dataset, ts;
sampling_strategy = NUTS(0.65), num_samples = 1000,
syms = [Turing.@varname(theta[i]) for i in 1:length(priors)])
param_initial, recon = Flux.destructure(chain)
nparameters = length(param_initial)

alpha = 0.09
sig = sqrt(1.0 / alpha)
@model function bayes_pinn(dataset)
theta = Vector{T}{undef, length(priors)}
for i in eachindex(priors)
theta[i] ~ NamedDist(priors[i], sym[i])
end
nnparameters ~ MvNormal(zeros(nparameters), sig .* ones(nparameters))

preds = pinn(prob, parameters, dataset, recon)
Vaibhavdixit02 marked this conversation as resolved.
Show resolved Hide resolved
for i in eachindex(ts)
datapoints ~ MvNormal(pred)
Vaibhavdixit02 marked this conversation as resolved.
Show resolved Hide resolved
Turing.@addlogprob! physloglikelihood(pred, μ)
Vaibhavdixit02 marked this conversation as resolved.
Show resolved Hide resolved
end
end

model = bayes_pinn(dataset)
ch = sample(model, sampling_strategy, num_samples)
return ch
end

function pendulum(du, u, p, t)
ω, L = p
x, y = u
du[1] = y
du[2] = -ω * y - (9.8 / L) * sin(x)
end

u0 = [1.0, 0.1]
tspan = (0.0, 10.0)
prob1 = ODEProblem(pendulum, u0, tspan, [1.0, 2.5])