[WIP] Bayesian PINN solver for ODE #692

Merged Aug 17, 2023 (40 commits; diff shown from 21 commits).

Commits (all by AstitvaAggarwal):
d897e3f  May 11, 2023  begin project
b77488a  Jun 4, 2023   Merge branch 'develop' of https://github.com/AstitvaAggarwal/NeuralPD…
907c3ed  Jun 6, 2023   Started by defining Combined likelihood and priors
c8b5dba  Jun 11, 2023  model + custom likelihood + adding tests
81c47cc  Jun 11, 2023  formatted stuff
56e7afc  Jun 17, 2023  It works but predictions are superbad
978e115  Jun 29, 2023  formats, added advancedhmc version of BPINN, BPINN test file
a793b3d  Jun 30, 2023  Tests -BNN, BPINN,changes in var,full/half logprob
e817b93  Jul 3, 2023   custom logprob in turing_MCMC ,advancedHMC_MCMC-added NN parameter pr…
7980448  Jul 4, 2023   Trying to fix formatting done by mistake in commit-978e115
b760c60  Jul 6, 2023   The BPINN ODE solver works well for interpolation, now just needs gen…
7e3a3cd  Jul 8, 2023   Performs well in Extrapolation,Interpolation.
31d1f65  Jul 10, 2023  autodiff takes ~30mins sample- test file example
8329b8f  Jul 11, 2023  Less Numerical errors, faster sampling(autodiff),better EBFMI estimat…
15ec776  Jul 15, 2023  Inverse parameter estimation, seperate priors for weights and biases
c2cce0b  Jul 17, 2023  Inverse problem gives incorrect parameter esim
dfed5ee  Jul 17, 2023  Tests for inverse problem estimation
f77d04a  Jul 17, 2023  parameter estimation does not work
7cb6141  Jul 22, 2023  Parameter Estimation works, Custom choice for Prior Distributions
e8b1449  Jul 22, 2023  fixed a small error,cleaned BPINNtests.jl
b4e090d  Jul 25, 2023  Now works with Lux.jl chains,parallel sampled chains,good ode paramet…
9c07c97  Aug 1, 2023   Cleared up majority of Tests, BPINN ODE solver done.
2c70494  Aug 1, 2023   minor changes
9a47e36  Aug 7, 2023   Tests run automatically
c24ed46  Aug 7, 2023   Test Run automatically attempt-1
4468534  Aug 7, 2023   minor change
0bf8152  Aug 7, 2023   small change again fr fr
b23db49  Aug 7, 2023   changes from reviews-1
08883bd  Aug 7, 2023   Removed extra packages used in Tests
a42b6f7  Aug 8, 2023   Merge branch 'SciML:master' into develop
48dfa66  Aug 9, 2023   Test dependancies errors
bca2743  Aug 9, 2023   uses AHMC convienience structs
5357b33  Aug 10, 2023  Improved Tests - 1
3b35fc3  Aug 11, 2023  improved Tests - 2
d467250  Aug 12, 2023  Fixed the Problem in Test
f2e502a  Aug 12, 2023  improved Tests
e479c47  Aug 12, 2023  pls work
4df2d63  Aug 12, 2023  This shall work!
ec24839  Aug 15, 2023  Type checking for dataset, Solver usage made clear
0428dc5  Aug 16, 2023  changes from reviews
Project.toml (3 additions, 0 deletions)

@@ -5,6 +5,7 @@ version = "5.7.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
@@ -15,9 +16,11 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
 IntegralsCubature = "c31f79ba-6e32-46d4-a52f-182a8ac42a54"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
src/NeuralPDE.jl (6 additions, 0 deletions)

@@ -48,6 +48,9 @@ include("transform_inf_integral.jl")
include("discretize.jl")
include("neural_adapter.jl")

# fixes #682
include("advancedHMC_MCMC.jl")

export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
KolmogorovParamDomain, NNParamKolmogorov,
@@ -62,4 +65,7 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
 MiniMaxAdaptiveLoss,
 LogOptions
 
+#fixes #682
+export ahmc_bayesian_pinn_ode
+
 end # module
src/advancedHMC_MCMC.jl (new file, 350 additions)

@@ -0,0 +1,350 @@
using AdvancedHMC, ForwardDiff, LogDensityProblems, LinearAlgebra, Distributions, Functors
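# NOTE: the file also relies on Lux, Flux, Optimisers, ComponentArrays, Random,
# ChainRulesCore, SciMLBase and MCMCChains, assumed to be imported by the
# enclosing NeuralPDE module.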

mutable struct LogTargetDensity{C, S, I}
dim::Int
prob::DiffEqBase.DEProblem
chain::C
st::S
dataset::Vector{Vector{Float64}}
priors::Vector{Distribution}
[Review thread on the priors field]
Member: This is going to be unstable: does it not affect performance?
Contributor (author): I've been timing (@time/@btime) the merged struct version against the version this review was left on (which uses prob::ODEProblem); the difference doesn't change much. In fact, on many runs the unstable version performs better.

phystd::Vector{Float64}
l2std::Vector{Float64}
autodiff::Bool
physdt::Float64
extraparams::Int
init_params::I

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::AbstractVector)
new{typeof(chain), Nothing, typeof(init_params)}(dim, prob, chain, nothing,
dataset, priors,
phystd, l2std, autodiff,
physdt, extraparams, init_params)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, dataset,
priors, phystd, l2std, autodiff, physdt, extraparams,
init_params::NamedTuple)
new{typeof(chain), typeof(st), typeof(init_params)}(dim, prob, chain, st,
dataset, priors,
phystd, l2std, autodiff,
physdt, extraparams,
init_params)
end
end

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim

function LogDensityProblems.capabilities(::LogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end
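# (order-1 capability: gradients of the log density are supplied via ForwardDiff
# when the Hamiltonian is constructed in ahmc_bayesian_pinn_ode below)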

function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params)
θ, st = Lux.setup(Random.default_rng(), chain)
return init_params, chain, st
end

function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params::Nothing)
θ, st = Lux.setup(Random.default_rng(), chain)
return θ, chain, st
end

function generate_Tar(chain::Flux.Chain, init_params)
θ, re = Flux.destructure(chain)
return init_params, re, nothing
end

function generate_Tar(chain::Flux.Chain, init_params::Nothing)
θ, re = Flux.destructure(chain)
# find_good_stepsize and phasepoint accept only Float64
θ = collect(Float64, θ)
return θ, re, nothing
end

# NN output at t

function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return Functors.fmap(get_ps, ps)
end
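# Illustrative round-trip (hypothetical layer sizes; not part of the solver):
#   ch = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1))
#   ps, st = Lux.setup(Random.default_rng(), ch)
#   θ = collect(Float64, vcat(ComponentArrays.ComponentArray(ps)))
#   vector_to_parameters(θ, ps)  # rebuilds a NamedTuple with the same shapes as ps

# NN ansatz used below: u(t) = u0 + (t - t0) * NN(t), so the initial condition
# is satisfied exactly by construction.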

function (f::LogTargetDensity{C, S})(t::AbstractVector,
θ) where {C <: Optimisers.Restructure, S}
f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* f.chain(θ)(adapt(parameterless_type(θ), t'))
end

function (f::LogTargetDensity{C, S})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, S}
θ = vector_to_parameters(θ, f.init_params)
# Batch via data as row vectors
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t'), θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* y
end

function (f::LogTargetDensity{C, S})(t::Number,
θ) where {C <: Optimisers.Restructure, S}
# must handle ODE systems as well, hence u0 is broadcast
f.prob.u0 .+ (t - f.prob.tspan[1]) * f.chain(θ)(adapt(parameterless_type(θ), [t]))
end

function (f::LogTargetDensity{C, S})(t::Number,
θ) where {C <: Lux.AbstractExplicitLayer, S}
θ = vector_to_parameters(θ, f.init_params)
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), [t]), θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.prob.u0 .+ (t .- f.prob.tspan[1]) .* y
end

# ODE du/dt via the NN ansatz
function NNodederi(phi::LogTargetDensity, t::AbstractVector, θ, autodiff::Bool)
if autodiff
hcat(ForwardDiff.derivative.(ti -> phi(ti, θ), t)...)
else
(phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))
end
end
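# (forward difference with step sqrt(eps): the standard choice for Float64,
# balancing truncation error against floating-point round-off)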

# physics loglikelihood over problem timespan
function physloglikelihood(Tar::LogTargetDensity, θ)
f = Tar.prob.f
p = Tar.prob.p
t = copy(Tar.dataset[end])

# parameter estimation chosen or not
if Tar.extraparams > 0
ode_params = Tar.extraparams == 1 ?
θ[((length(θ) - Tar.extraparams) + 1):length(θ)][1] :
θ[((length(θ) - Tar.extraparams) + 1):length(θ)]
else
ode_params = p == SciMLBase.NullParameters() ? [] : p
end

# train the NN derivative on the dataset timepoints as well as beyond them, but within the timespan
autodiff = Tar.autodiff
dt = Tar.physdt

if t[end] != Tar.prob.tspan[2]
append!(t, collect(Float64, t[end]:dt:Tar.prob.tspan[2]))
end

# compare derivatives(matrix)
out = Tar(t, θ[1:(length(θ) - Tar.extraparams)])

# reject samples case
if any(isinf, out[:, 1]) || any(isinf, ode_params)
return -Inf
end

# physsol is a Vector of du-vectors; handle the scalar-u case (a Float is passed) separately
if length(out[:, 1]) == 1
physsol = [f(out[:, i][1],
ode_params,
t[i])
for i in 1:length(out[1, :])]
else
physsol = [f(out[:, i],
ode_params,
t[i])
for i in 1:length(out[1, :])]
end
physsol = hcat(physsol...)
[Review suggestion (Member)]
-physsol = hcat(physsol...)
+physsol = reduce(hcat,physsol)

# convert to matrix as nnsol
nnsol = NNodederi(Tar, t, θ[1:(length(θ) - Tar.extraparams)], autodiff)

physlogprob = 0
for i in 1:length(Tar.prob.u0)
# can add phystd[i] for u[i]
physlogprob += logpdf(MvNormal(nnsol[i, :], Tar.phystd[i]), physsol[i, :])
end
return physlogprob
end

# L2 loss log-likelihood (needed mainly for ODE parameter estimation)
function L2LossData(Tar::LogTargetDensity, θ)
# nn is a matrix (row i corresponds to u[i])
if Tar.extraparams == 0
return 0
else
nn = Tar(Tar.dataset[end], θ[1:(length(θ) - Tar.extraparams)])

L2logprob = 0
for i in 1:length(Tar.prob.u0)
# the i-th dataset vector must correspond to u[i]; e.g. nn[1, :] is the first state in Lotka-Volterra
L2logprob += logpdf(MvNormal(nn[i, :], Tar.l2std[i]), Tar.dataset[i])
end
return L2logprob
end
end

# priors for NN parameters + ODE constants
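# (by construction in ahmc_bayesian_pinn_ode: priors[1] is one MvNormal over all
# NN weights and biases, priors[2:end] are the per-parameter ODE priors)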
function priorweights(Tar::LogTargetDensity, θ)
allparams = Tar.priors
# vector of priors for the ODE parameters
invpriors = allparams[2:end]

# nn weights
nnwparams = allparams[1]

if Tar.extraparams > 0
invlogpdf = sum(logpdf(invpriors[length(θ) - i + 1], θ[i])
for i in (length(θ) - Tar.extraparams + 1):length(θ); init = 0.0)

return invlogpdf + logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)])
else
return logpdf(nnwparams, θ)
end
end

function integratorchoice(Integrator, initial_ϵ; jitter_rate = 3.0,
tempering_rate = 3.0)
if Integrator == JitteredLeapfrog
Integrator(initial_ϵ, jitter_rate)
elseif Integrator == TemperedLeapfrog
Integrator(initial_ϵ, tempering_rate)
else
Integrator(initial_ϵ)
end
end
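# e.g. integratorchoice(JitteredLeapfrog, 0.1) returns JitteredLeapfrog(0.1, 3.0);
# a plain Leapfrog just receives the initial stepsize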

function proposalchoice(Sampler, Integrator; n_steps = 50,
trajectory_length = 30.0)
if Sampler == StaticTrajectory
Sampler(Integrator, n_steps)
elseif Sampler == AdvancedHMC.HMCDA
Sampler(Integrator, trajectory_length)
else
Sampler(Integrator)
end
end

# dataset is (x̂, t): noisy observations of u followed by their timepoints
# priors: a pdf for the NN weights/biases plus pdfs for the ODE params
# (some kwarg defaults below are specific to the Lotka-Volterra test example)
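# ahmc_bayesian_pinn_ode(prob, chain, dataset; kwargs...) samples a Bayesian PINN
# posterior over the NN parameters (and optionally ODE parameters) with AdvancedHMC.
# Arguments, summarised from the code below:
#   prob          out-of-place ODEProblem
#   chain         Lux.AbstractExplicitLayer or Flux.Chain
#   dataset       Vector{Vector{Float64}} of the form [x̂ per state..., t]
#   init_params   initial NN parameters (default: the chain's own initialisation)
#   nchains       number of parallel (threaded) chains
#   draw_samples  number of HMC samples to draw
#   l2std, phystd standard deviations of the data and physics likelihoods
#   priorsNNw     (mean, std) of the Gaussian prior over NN weights and biases
#   param         vector of prior Distributions for estimated ODE parameters
#   autodiff      ForwardDiff (true) or finite differences for du/dt
#   physdt        collocation step used beyond the dataset timepoints
#   Proposal, Adaptor, Integrator, Metric   AdvancedHMC components
# Returns (mcmc_chain, samples, stats), or (chains, samplesc, statsc) when nchains > 1.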

[Review comment (Member) on the signature: It would be good to support Float32 for GPU use]
function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain,
dataset::Vector{Vector{Float64}};
init_params = nothing, nchains = 1,
draw_samples = 1000, l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [],
autodiff = false, physdt = 1 / 20.0f0,
Proposal = StaticTrajectory,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Integrator = Leapfrog,
Metric = DiagEuclideanMetric)

# NN parameter prior mean and variance (priorsNNw must be a tuple)
if isinplace(prob)
throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
end

if chain isa Lux.AbstractExplicitLayer || chain isa Flux.Chain
# Flux-vector, Lux-Named Tuple
initial_nnθ, recon, st = generate_Tar(chain, init_params)
else
error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported")
end

if nchains > Threads.nthreads()
throw(error("number of chains is greater than available threads"))
elseif nchains < 1
throw(error("number of chains must be at least 1"))
end

if chain isa Lux.AbstractExplicitLayer
# Lux chain: go through a ComponentArray, since vector_to_parameters later needs a NamedTuple and AHMC works on Float64 vectors
initial_θ = collect(Float64, vcat(ComponentArrays.ComponentArray(initial_nnθ)))
else
initial_θ = initial_nnθ
end

# adding ode parameter estimation
nparameters = length(initial_θ)
ninv = length(param)
priors = [MvNormal(priorsNNw[1] * ones(nparameters), priorsNNw[2] * ones(nparameters))]

# append ODE params to the full parameter vector
if ninv > 0
# initialise ODE params from their priors (first distribution parameter, e.g. the mean of a Normal)
initial_θ = vcat(initial_θ, [Distributions.params(param[i])[1] for i in 1:ninv])
priors = vcat(priors, param)
nparameters += ninv
end

# dimension is the total number of sampled parameters; initial_nnθ is the NamedTuple for Lux chains
ℓπ = LogTargetDensity(nparameters, prob, recon, st, dataset, priors,
phystd, l2std, autodiff, physdt, ninv,
initial_nnθ)

t0 = prob.tspan[1]
try
ℓπ(t0, initial_θ[1:(nparameters - ninv)])
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
else
throw(err)
end
end

# Define Hamiltonian system
metric = Metric(nparameters)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

# parallel sampling option
if nchains != 1
# Cache to store the chains
chains = Vector{Any}(undef, nchains)
statsc = Vector{Any}(undef, nchains)
samplesc = Vector{Any}(undef, nchains)

Threads.@threads for i in 1:nchains
# each chain gets different initial NN parameter values (better posterior exploration)
initial_θ = vcat(randn(nparameters - ninv),
initial_θ[(nparameters - ninv + 1):end])
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integrator, initial_ϵ)
proposal = proposalchoice(Proposal, integrator)
adaptor = Adaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

samples, stats = sample(hamiltonian, proposal, initial_θ, draw_samples, adaptor;
progress = true, verbose = false)
samplesc[i] = samples
statsc[i] = stats

mcmc_chain = Chains(hcat(samples...)')
chains[i] = mcmc_chain
end

return chains, samplesc, statsc
else
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integrator, initial_ϵ)
proposal = proposalchoice(Proposal, integrator)
adaptor = Adaptor(MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

samples, stats = sample(hamiltonian, proposal, initial_θ, draw_samples, adaptor;
progress = true)
# return a basic Chains object, the raw samples, and the sampler stats
matrix_samples = hcat(samples...)
mcmc_chain = Chains(matrix_samples')
return mcmc_chain, samples, stats
end
end
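For reference, a minimal usage sketch of the new solver; the ODE, network, and hyperparameters here are illustrative, not taken from the PR's tests:

using NeuralPDE, Flux, OrdinaryDiffEq

# du/dt = cos(2πt), out-of-place, so u(t) = u0 + sin(2πt)/(2π)
f(u, p, t) = cos(2 * pi * t)
prob = ODEProblem(f, 0.0, (0.0, 2.0))

# noisy dataset (x̂, t) covering part of the timespan
ta = range(0.0, 0.75, length = 50)
u = [sin(2 * pi * ti) / (2 * pi) for ti in ta]
x̂ = u .+ 0.02 .* randn(length(u))
dataset = [x̂, collect(ta)]

chainflux = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 1))

mcmc_chain, samples, stats = ahmc_bayesian_pinn_ode(prob, chainflux, dataset;
draw_samples = 1000, l2std = [0.05], phystd = [0.05], priorsNNw = (0.0, 3.0))

Note that with the default param = [], L2LossData returns 0 and the dataset only supplies timepoints; pass prior Distributions in param to fit ODE parameters against x̂ as well.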