Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Bayesian PINN solver for ODE #692

Merged
merged 40 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d897e3f
begin project
AstitvaAggarwal May 11, 2023
b77488a
Merge branch 'develop' of https://github.com/AstitvaAggarwal/NeuralPD…
AstitvaAggarwal Jun 4, 2023
907c3ed
Started by defining Combined likelihood and priors
AstitvaAggarwal Jun 6, 2023
c8b5dba
model + custom likelihood + adding tests
AstitvaAggarwal Jun 11, 2023
81c47cc
formatted stuff
AstitvaAggarwal Jun 11, 2023
56e7afc
It works but predictions are superbad
AstitvaAggarwal Jun 17, 2023
978e115
formats, added advancedhmc version of BPINN, BPINN test file
AstitvaAggarwal Jun 29, 2023
a793b3d
Tests -BNN, BPINN,changes in var,full/half logprob
AstitvaAggarwal Jun 30, 2023
e817b93
custom logprob in turing_MCMC ,advancedHMC_MCMC-added NN parameter pr…
AstitvaAggarwal Jul 3, 2023
7980448
Trying to fix formatting done by mistake in commit-978e115
AstitvaAggarwal Jul 4, 2023
b760c60
The BPINN ODE solver works well for interpolation, now just needs gen…
AstitvaAggarwal Jul 6, 2023
7e3a3cd
Performs well in Extrapolation,Interpolation.
AstitvaAggarwal Jul 8, 2023
31d1f65
autodiff takes ~30mins sample- test file example
AstitvaAggarwal Jul 10, 2023
8329b8f
Less Numerical errors, faster sampling(autodiff),better EBFMI estimat…
AstitvaAggarwal Jul 11, 2023
15ec776
Inverse parameter estimation, seperate priors for weights and biases
AstitvaAggarwal Jul 15, 2023
c2cce0b
Inverse problem gives incorrect parameter esim
AstitvaAggarwal Jul 17, 2023
dfed5ee
Tests for inverse problem estimation
AstitvaAggarwal Jul 17, 2023
f77d04a
parameter estimation does not work
AstitvaAggarwal Jul 17, 2023
7cb6141
Parameter Estimation works, Custom choice for Prior Distributions
AstitvaAggarwal Jul 22, 2023
e8b1449
fixed a small error,cleaned BPINNtests.jl
AstitvaAggarwal Jul 22, 2023
b4e090d
Now works with Lux.jl chains,parallel sampled chains,good ode paramet…
AstitvaAggarwal Jul 25, 2023
9c07c97
Cleared up majority of Tests, BPINN ODE solver done.
AstitvaAggarwal Aug 1, 2023
2c70494
minor changes
AstitvaAggarwal Aug 1, 2023
9a47e36
Tests run automatically
AstitvaAggarwal Aug 7, 2023
c24ed46
Test Run automatically attempt-1
AstitvaAggarwal Aug 7, 2023
4468534
minor change
AstitvaAggarwal Aug 7, 2023
0bf8152
small change again fr fr
AstitvaAggarwal Aug 7, 2023
b23db49
changes from reviews-1
AstitvaAggarwal Aug 7, 2023
08883bd
Removed extra packages used in Tests
AstitvaAggarwal Aug 7, 2023
a42b6f7
Merge branch 'SciML:master' into develop
AstitvaAggarwal Aug 8, 2023
48dfa66
Test dependancies errors
AstitvaAggarwal Aug 9, 2023
bca2743
uses AHMC convienience structs
AstitvaAggarwal Aug 9, 2023
5357b33
Improved Tests - 1
AstitvaAggarwal Aug 10, 2023
3b35fc3
improved Tests - 2
AstitvaAggarwal Aug 11, 2023
d467250
Fixed the Problem in Test
AstitvaAggarwal Aug 12, 2023
f2e502a
improved Tests
AstitvaAggarwal Aug 12, 2023
e479c47
pls work
AstitvaAggarwal Aug 12, 2023
4df2d63
This shall work!
AstitvaAggarwal Aug 12, 2023
ec24839
Type checking for dataset, Solver usage made clear
AstitvaAggarwal Aug 15, 2023
0428dc5
changes from reviews
AstitvaAggarwal Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "5.7.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Expand All @@ -18,6 +19,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
IntegralsCubature = "c31f79ba-6e32-46d4-a52f-182a8ac42a54"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Expand All @@ -33,6 +35,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
7 changes: 7 additions & 0 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ include("transform_inf_integral.jl")
include("discretize.jl")
include("neural_adapter.jl")

# fixes #682
include("turing_MCMC.jl")
include("advancedHMC_MCMC.jl")

export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
KolmogorovParamDomain, NNParamKolmogorov,
Expand All @@ -62,4 +66,7 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
MiniMaxAdaptiveLoss,
LogOptions

#fixes #682
export bayesian_pinn_ode, ahmc_bayesian_pinn_ode

end # module
215 changes: 215 additions & 0 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
using AdvancedHMC, ForwardDiff, LogDensityProblems
using LinearAlgebra

mutable struct LogTargetDensity{C, S}
dim::Int
prob::DiffEqBase.DEProblem
AstitvaAggarwal marked this conversation as resolved.
Show resolved Hide resolved
chain::C
st::S
dataset::Tuple{AbstractVector, AbstractVector}
priorsNN::Tuple{Float64, Float64}
phystd::Float64
l2std::Float64
autodiff::Bool
physdt::Float64

function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, dataset,

Check warning on line 16 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L16

Added line #L16 was not covered by tests
priorsNN, phystd, l2std, autodiff, physdt)
new{typeof(chain), Nothing}(dim, prob, chain, nothing,

Check warning on line 18 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L18

Added line #L18 was not covered by tests
dataset, priorsNN,
phystd, l2std, autodiff,
physdt)
end
function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, dataset,

Check warning on line 23 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L23

Added line #L23 was not covered by tests
priorsNN, phystd, l2std, autodiff, physdt)
new{typeof(chain), typeof(st)}(dim, prob, re, st,

Check warning on line 25 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L25

Added line #L25 was not covered by tests
dataset, priorsNN,
phystd, l2std, autodiff,
physdt)
end
end

function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
return physloglikelihood(Tar, θ) + L2LossData(Tar, θ) + priorweights(Tar, θ)

Check warning on line 33 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
end

LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim

Check warning on line 36 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L36

Added line #L36 was not covered by tests

function LogDensityProblems.capabilities(::LogTargetDensity)
LogDensityProblems.LogDensityOrder{0}()

Check warning on line 39 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
end

function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params)
θ, st = Lux.setup(Random.default_rng(), chain)
return ComponentArrays.ComponentArray(init_params), chain, st

Check warning on line 44 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L42-L44

Added lines #L42 - L44 were not covered by tests
end

function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params::Nothing)
θ, st = Lux.setup(Random.default_rng(), chain)
return ComponentArrays.ComponentArray(θ), chain, st

Check warning on line 49 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L47-L49

Added lines #L47 - L49 were not covered by tests
end

function generate_Tar(chain::Flux.Chain, init_params)
θ, re = Flux.destructure(chain)

Check warning on line 53 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L52-L53

Added lines #L52 - L53 were not covered by tests
# if init_params==nothing
# θ = collect(Float64, θ)
# return θ, re, nothing
# else
# return init_params, re, nothing
# end
return init_params, re, nothing

Check warning on line 60 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L60

Added line #L60 was not covered by tests
end
function generate_Tar(chain::Flux.Chain, init_params::Nothing)
θ, re = Flux.destructure(chain)

Check warning on line 63 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L62-L63

Added lines #L62 - L63 were not covered by tests
# find_good_stepsize takes only float64?
θ = collect(Float64, θ)
return θ, re, nothing

Check warning on line 66 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L65-L66

Added lines #L65 - L66 were not covered by tests
end

# nn OUTPUT AT t
function (f::LogTargetDensity{C, S})(t::AbstractVector,

Check warning on line 70 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L70

Added line #L70 was not covered by tests
θ) where {C <: Optimisers.Restructure, S}
# f.prob.u0 .+ (t .- f.prob.tspan[1]) .* vec(f.chain(θ)(t'))
f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* f.chain(θ)(adapt(parameterless_type(θ), t'))

Check warning on line 73 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L73

Added line #L73 was not covered by tests
# f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* f.chain(θ)(t')
end

function (f::LogTargetDensity{C, S})(t::AbstractVector,

Check warning on line 77 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L77

Added line #L77 was not covered by tests
θ) where {C <: Lux.AbstractExplicitLayer, S}
# Batch via data as row vectors
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t'), θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y

Check warning on line 82 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L80-L82

Added lines #L80 - L82 were not covered by tests
end

function (f::LogTargetDensity{C, S})(t::Number,

Check warning on line 85 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L85

Added line #L85 was not covered by tests
θ) where {C <: Optimisers.Restructure, S}
f.prob.u0 + (t - f.prob.tspan[1]) * first(f.chain(θ)(adapt(parameterless_type(θ), [t])))

Check warning on line 87 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L87

Added line #L87 was not covered by tests
end

using ForwardDiff: Chunk
# ODE DU/DX
function NNodederi(phi::LogTargetDensity, t::AbstractVector, θ, autodiff::Bool)
if autodiff

Check warning on line 93 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L92-L93

Added lines #L92 - L93 were not covered by tests
# diag(ForwardDiff.jacobian(t -> phi(t, θ), t))
[ForwardDiff.derivative(ti -> phi(ti, θ), ti) for ti in t]

Check warning on line 95 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L95

Added line #L95 was not covered by tests
Vaibhavdixit02 marked this conversation as resolved.
Show resolved Hide resolved
else
(phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))

Check warning on line 97 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L97

Added line #L97 was not covered by tests
end
end

function physloglikelihood(Tar::LogTargetDensity, θ)
p = Tar.prob.p
f = Tar.prob.f
var = Tar.phystd^2
autodiff = Tar.autodiff
dt = Tar.physdt
t = collect(Float64, Tar.prob.tspan[1]:dt:Tar.prob.tspan[2])

Check warning on line 107 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L101-L107

Added lines #L101 - L107 were not covered by tests

# compare derivatives
out = Tar(t, θ)

Check warning on line 110 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L110

Added line #L110 was not covered by tests
# print(size(out))
physsol = [f(out[i], p, t[i]) for i in eachindex(out)]
nnsol = NNodederi(Tar, t, θ, autodiff)

Check warning on line 113 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L112-L113

Added lines #L112 - L113 were not covered by tests

# distribution's mean is forwarddiff diag(jacobian)
# if autodiff
# nnsol = vec(diag(nnsol))
# end
# typeof(nnsol), size(nnsol), typeof(physsol), size(physsol)
n = length(nnsol)
return logpdf(MvNormal(nnsol, Diagonal(var .* ones(n))), physsol)

Check warning on line 121 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L120-L121

Added lines #L120 - L121 were not covered by tests
end

# standard MvNormal Dist Assume
function L2LossData(Tar::LogTargetDensity, θ)
nn = vec(Tar(Tar.dataset[2], θ))
n = length(nn)
var = Tar.l2std^2
return logpdf(MvNormal(nn, Diagonal(var .* ones(n))), Tar.dataset[1])

Check warning on line 129 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L125-L129

Added lines #L125 - L129 were not covered by tests
end

function priorweights(Tar::LogTargetDensity, θ)
params = Tar.priorsNN
return logpdf(MvNormal(θ, Diagonal(params[2]^2 .* ones(length(θ)))),

Check warning on line 134 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L132-L134

Added lines #L132 - L134 were not covered by tests
params[1] * ones(length(θ)))
end

# dataset would be (x̂,t)
# priors: pdf for W,b + pdf for ODE params
function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,

Check warning on line 140 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L140

Added line #L140 was not covered by tests
dataset::Tuple{AbstractVector, AbstractVector};
init_params = nothing,
draw_samples = 1000, l2std = 0.08,
phystd = 0.08, priorsNN = (0, 2), autodiff = false,
physdt = 1 / 20.0f0,
Proposal = AdvancedHMC.NUTS{MultinomialTS,
GeneralisedNoUTurn},
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Integrator = Leapfrog, Metric = DiagEuclideanMetric)
# NN parameter prior mean and variance(PriorsNN must be a tuple)
if isinplace(prob)
throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

Check warning on line 152 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L151-L152

Added lines #L151 - L152 were not covered by tests
end

if chain isa Lux.AbstractExplicitLayer || chain isa Flux.Chain
initial_θ, recon, st = generate_Tar(chain, init_params)

Check warning on line 156 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L155-L156

Added lines #L155 - L156 were not covered by tests
else
error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported")

Check warning on line 158 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L158

Added line #L158 was not covered by tests
end

# adding ode parameter estimation?
nparameters = length(initial_θ)
ℓπ = LogTargetDensity(nparameters, prob, recon, st, dataset, priorsNN,

Check warning on line 163 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L162-L163

Added lines #L162 - L163 were not covered by tests
phystd, l2std, autodiff, physdt)

# [add f(t,θ) for t being a number]
# try
# ℓπ(t0, initial_θ)
# catch err
# if isa(err, DimensionMismatch)
# throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
# else
# throw(err)
# end
# end

# return physloglikelihood(ℓπ, initial_θ)
n_samples = draw_samples
metric = Metric(nparameters)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)

Check warning on line 181 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L178-L181

Added lines #L178 - L181 were not covered by tests

# choices for integrators?
# [define n for JL and a for TL]
# if Integrator == JitteredLeapfrog(n)
# integrator = JitteredLeapfrog(initial_ϵ, n)
# elseif Integrator == TemperedLeapfrog(a)
# integrator == TemperedLeapfrog(initial_ϵ, a)
# else
# integrator = Leapfrog(initial_ϵ)
# end

integrator = Leapfrog(initial_ϵ)
proposal = Proposal(integrator)
adaptor = Adaptor(MassMatrixAdaptor(metric),

Check warning on line 195 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L193-L195

Added lines #L193 - L195 were not covered by tests
StepSizeAdaptor(targetacceptancerate, integrator))

samples, stats = sample(hamiltonian, proposal, initial_θ,

Check warning on line 198 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L198

Added line #L198 was not covered by tests
n_samples, adaptor;
progress = true)

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = Chains(matrix_samples')
return mcmc_chain, samples, stats

Check warning on line 205 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L203-L205

Added lines #L203 - L205 were not covered by tests
end

# non vectorised functions(i noticed sampling time increase)
# function NNodederi(phi::odeByNN, t::Number, θ, autodiff::Bool)
# if autodiff
# ForwardDiff.jacobian(t -> phi(t, θ), t)
# else
# (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
# end
# end
125 changes: 125 additions & 0 deletions src/turing_MCMC.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using Turing, Distributions
struct odeByNN{C, T, U}
chain::C
u0::U
t0::T

function odeByNN(re::Optimisers.Restructure, t, u0)
new{typeof(re), typeof(t), typeof(u0)}(re, t, u0)

Check warning on line 8 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L7-L8

Added lines #L7 - L8 were not covered by tests
end
end

function generate_phi(chain::Flux.Chain, t, u0, init_params::Nothing)
θ, re = Flux.destructure(chain)
odeByNN(re, t, u0), θ

Check warning on line 14 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L12-L14

Added lines #L12 - L14 were not covered by tests
end

# nn OUTPUT AT t
function (f::odeByNN{C, T, U})(t::Number,

Check warning on line 18 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L18

Added line #L18 was not covered by tests
θ) where {C <: Optimisers.Restructure, T, U}
f.u0 + (t - f.t0) * first(f.chain(θ)(adapt(parameterless_type(θ), [t])))

Check warning on line 20 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L20

Added line #L20 was not covered by tests
end

function (f::odeByNN{C, T, U})(t::AbstractVector,

Check warning on line 23 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L23

Added line #L23 was not covered by tests
θ) where {C <: Optimisers.Restructure, T, U}
f.u0 .+ (t .- f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), t'))

Check warning on line 25 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L25

Added line #L25 was not covered by tests
end

# ODE DU/DX
function NNodederi(phi::odeByNN, t::Number, θ, autodiff::Bool)
if autodiff
ForwardDiff.jacobian(t -> phi(t, θ), t)

Check warning on line 31 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L29-L31

Added lines #L29 - L31 were not covered by tests
else
(phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))

Check warning on line 33 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L33

Added line #L33 was not covered by tests
end
end

function NNodederi(phi::odeByNN, t::AbstractVector, θ, autodiff::Bool)
if autodiff
ForwardDiff.jacobian(t -> phi(t, θ), t)

Check warning on line 39 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L37-L39

Added lines #L37 - L39 were not covered by tests
else
(phi(t .+ sqrt(eps(eltype(t))), θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))

Check warning on line 41 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L41

Added line #L41 was not covered by tests
end
end

function physloglikelihood(chain::Any, prob::DiffEqBase.DEProblem,

Check warning on line 45 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L45

Added line #L45 was not covered by tests
t::AbstractVector; var = 0.5)
u0 = prob.u0
t0 = t[1]
p = prob.p
f = prob.f

Check warning on line 50 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L47-L50

Added lines #L47 - L50 were not covered by tests
# let this be(will fix)
autodiff = false

Check warning on line 52 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L52

Added line #L52 was not covered by tests

phi, initparams = generate_phi(chain, t0, u0, nothing)

Check warning on line 54 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L54

Added line #L54 was not covered by tests

μ = vec([f(phi(t[i], initparams), p, u0) for i in eachindex(t)])
physsol = vec([NNodederi(phi, t[i], initparams, autodiff) for i in eachindex(t)])

Check warning on line 57 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L56-L57

Added lines #L56 - L57 were not covered by tests
# print(typeof(μ))
# print(typeof(physsol))
# To reduce heap allocations but some erros came up
# μ = similar(t)
# physsol = similar(t)
# μ = f(phi(t, initparams), p, u0)
# physsol = NNodederi(phi, t, initparams, autodiff)
# for i in eachindex(t)
# μ[i] = f(phi(t[i], initparams), p, u0)
# physsol[i] = NNodederi(phi, t[i], initparams, autodiff)
# end
return sum(abs2, (μ .- physsol) ./ (-2 * (var^2)))

Check warning on line 69 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L69

Added line #L69 was not covered by tests
# return loglikelihood(MvNormal(physsol - μ,
# Diagonal(var .* ones(Float64, length(μ)))))
end

# dataset would be (x̂,t)
# priors: pdf for W,b + pdf for ODE params
function bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain, dataset;

Check warning on line 76 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L76

Added line #L76 was not covered by tests
sampling_strategy = Turing.NUTS(0.65), num_samples = 1000)
param_initial, recon = Flux.destructure(chain)
nparameters = length(param_initial)

Check warning on line 79 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L78-L79

Added lines #L78 - L79 were not covered by tests

if isinplace(prob)
throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

Check warning on line 82 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L81-L82

Added lines #L81 - L82 were not covered by tests
end

alpha = 0.09
sig = 0.2

Check warning on line 86 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L85-L86

Added lines #L85 - L86 were not covered by tests

p = prob.p
f = prob.f
u0 = prob.u0
t0 = prob.tspan[1]
t = dataset[2]

Check warning on line 92 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L88-L92

Added lines #L88 - L92 were not covered by tests
# compare derivatives
phi, initparams = generate_phi(chain, t0, u0, nothing)
physsol = vec([f(phi(t[i], param_initial), p, t[i]) for i in eachindex(t)])

Check warning on line 95 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L94-L95

Added lines #L94 - L95 were not covered by tests

DynamicPPL.@model function bayes_pinn(dataset, physsol, phi)

Check warning on line 97 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L97

Added line #L97 was not covered by tests
# parameter estimation?

# prior for NN parameters(not included bias yet?) - P(Θ)
nnparameters ~ MvNormal(zeros(nparameters), Diagonal(sig .* ones(nparameters)))
preds = [phi(ti, nnparameters) for ti in dataset[2]]

Check warning on line 102 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L101-L102

Added lines #L101 - L102 were not covered by tests

# # likelihood for NN pred vs Equation satif - P(phys | Θ)
# if DynamicPPL.leafcontext(__context__) !== Turing.PriorContext()
# Turing.@addlogprob! physloglikelihood(nn, prob, dataset[2], var = sig)
# end
physsol ~ MvNormal(vec(preds), Diagonal(sig .* ones(length(dataset[2]))))

Check warning on line 108 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L108

Added line #L108 was not covered by tests
# # likelihood for dataset vs NN pred - P( X̄ | Θ)
dataset[1] ~ MvNormal(vec(preds), Diagonal(sig .* ones(length(dataset[2]))))

Check warning on line 110 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L110

Added line #L110 was not covered by tests
end

model = bayes_pinn(dataset, physsol, phi)
ch = sample(model, sampling_strategy, num_samples)
return ch

Check warning on line 115 in src/turing_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/turing_MCMC.jl#L113-L115

Added lines #L113 - L115 were not covered by tests
end

# ----------need speed up
# the phase point struct
# create custom distri?
# using chain with updated parameters in physloglikelihood and L2LossData

# ----------more code and compatibility
# allow options for prior,likelihood distributions
# add support for Lux chains
Loading
Loading