diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 052a8bdd4d..0a89ffe4e3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,14 +13,15 @@ jobs: fail-fast: false matrix: group: + - ODEBPINN - NNPDE1 - NNPDE2 - AdaptiveLoss - Logging - Forward version: - - '1' - - '1.6' + - "1" + - "1.6" steps: - uses: actions/checkout@v3 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 2c1ceaa3c6..812263960b 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "5.8.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" @@ -15,10 +16,13 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Integrals = "de52edbc-65ea-441a-8357-d3a637375a31" IntegralsCubature = "c31f79ba-6e32-46d4-a52f-182a8ac42a54" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index f61544274f..c5cddfddff 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -22,6 +22,8 @@ using DomainSets using Symbolics using Symbolics: wrap, unwrap, arguments, operation using SymbolicUtils +using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains + import ModelingToolkit: value, nameof, toexpr, build_expr, 
expand_derivatives
 import DomainSets: Domain, ClosedInterval
 import ModelingToolkit: Interval, infimum, supremum #,Ball
@@ -47,6 +49,7 @@ include("rode_solve.jl")
 include("transform_inf_integral.jl")
 include("discretize.jl")
 include("neural_adapter.jl")
+include("advancedHMC_MCMC.jl")
 
 export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
@@ -60,6 +63,6 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize,
        AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
        MiniMaxAdaptiveLoss,
-       LogOptions
+       LogOptions, ahmc_bayesian_pinn_ode
 
 end # module
diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl
new file mode 100644
index 0000000000..6b7a1cfce3
--- /dev/null
+++ b/src/advancedHMC_MCMC.jl
@@ -0,0 +1,510 @@
+"""
+    LogTargetDensity(dim, prob, chain, st, dataset, priors, phystd, l2std, autodiff,
+        physdt, extraparams, init_params)
+
+Unnormalised log-posterior of a Bayesian PINN for an ODE, exposed through the
+`LogDensityProblems` interface so that AdvancedHMC can sample it. Calling the object
+as `Tar(t, θ)` evaluates the trial solution `u0 + (t - t0) * NN(t; θ)`.
+
+# Fields
+- `dim`: total number of sampled parameters (NN parameters plus ODE parameters).
+- `prob`: the ODE problem; `f`, `u0`, `tspan` and `p` are read from it.
+- `chain`: an `Optimisers.Restructure` (Flux) or the Lux layer itself.
+- `st`: Lux layer state; always `nothing` for Flux.
+- `dataset`: `[u₁, …, uₙ, t]`; the last vector holds the measurement timepoints.
+- `priors`: `priors[1]` is the prior over all NN parameters, `priors[2:end]` are the
+  priors of the ODE parameters.
+- `phystd`, `l2std`: per-component standard deviations of the physics and data
+  likelihoods.
+- `autodiff`: `true` for ForwardDiff time derivatives, `false` for a forward difference.
+- `physdt`: spacing of the grid on which the ODE residual is enforced.
+- `extraparams`: number of ODE parameters being estimated (stored at the end of `θ`).
+- `init_params`: initial NN parameters; for Lux also the `NamedTuple` layout template.
+"""
+mutable struct LogTargetDensity{C, S, I, P <: Vector{<:Distribution},
+    D <: Vector{<:Vector{<:AbstractFloat}}}
+    dim::Int
+    prob::DiffEqBase.ODEProblem
+    chain::C
+    st::S
+    dataset::D
+    priors::P
+    phystd::Vector{Float64}
+    l2std::Vector{Float64}
+    autodiff::Bool
+    physdt::Float64
+    extraparams::Int
+    init_params::I
+
+    # Flux: parameters are a flat vector and there is no layer state (`st` is ignored)
+    function LogTargetDensity(dim, prob, chain::Optimisers.Restructure, st, dataset,
+        priors, phystd, l2std, autodiff, physdt, extraparams,
+        init_params::AbstractVector)
+        new{typeof(chain), Nothing, typeof(init_params), typeof(priors),
+            typeof(dataset)}(dim, prob, chain, nothing, dataset, priors, phystd, l2std,
+            autodiff, physdt, extraparams, init_params)
+    end
+
+    # Lux: parameters are a NamedTuple and the layer state is stored in `st`
+    function LogTargetDensity(dim, prob, chain::Lux.AbstractExplicitLayer, st, dataset,
+        priors, phystd, l2std, autodiff, physdt, extraparams,
+        init_params::NamedTuple)
+        new{typeof(chain), typeof(st), typeof(init_params), typeof(priors),
+            typeof(dataset)}(dim, prob, chain, st, dataset, priors, phystd, l2std,
+            autodiff, physdt, extraparams, init_params)
+    end
+end
+
+# log-posterior = physics log-likelihood + log-prior + data log-likelihood
+function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ)
+    return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ)
+end
+
+LogDensityProblems.dimension(Tar::LogTargetDensity) = Tar.dim
+
+function LogDensityProblems.capabilities(::LogTargetDensity)
+    LogDensityProblems.LogDensityOrder{1}()
+end
+
+# generate_Tar returns (initial NN parameters, callable network, Lux state or nothing);
+# user-supplied `init_params` take precedence over a fresh initialisation
+function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params)
+    _, st = Lux.setup(Random.default_rng(), chain)
+    return init_params, chain, st
+end
+
+function generate_Tar(chain::Lux.AbstractExplicitLayer, init_params::Nothing)
+    θ, st = Lux.setup(Random.default_rng(), chain)
+    return θ, chain, st
+end
+
+function generate_Tar(chain::Flux.Chain, init_params)
+    _, re = Flux.destructure(chain)
+    return init_params, re, nothing
+end
+
+function generate_Tar(chain::Flux.Chain, init_params::Nothing)
+    θ, re = Flux.destructure(chain)
+    return θ, re, nothing
+end
+
+# For vector of samples to Lux ComponentArrays: consecutive slices of the flat vector
+# `ps_new` are reshaped to the array sizes found in the template `ps`
+function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
+    @assert length(ps_new) == Lux.parameterlength(ps)
+    i = 1
+    function get_ps(x)
+        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
+        i += length(x)
+        return z
+    end
+    return Functors.fmap(get_ps, ps)
+end
+
+# nn OUTPUT AT t (vector of timepoints, one column per timepoint)
+function (f::LogTargetDensity{C, S})(t::AbstractVector,
+    θ) where {C <: Optimisers.Restructure, S}
+    f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* f.chain(θ)(adapt(parameterless_type(θ), t'))
+end
+
+function (f::LogTargetDensity{C, S})(t::AbstractVector,
+    θ) where {C <: Lux.AbstractExplicitLayer, S}
+    θ = vector_to_parameters(θ, f.init_params)
+    y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t'), θ, f.st)
+    ChainRulesCore.@ignore_derivatives f.st = st
+    f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* y
+end
+
+# nn OUTPUT AT a single timepoint
+function (f::LogTargetDensity{C, S})(t::Number,
+    θ) where {C <: Optimisers.Restructure, S}
+    # must handle paired odes hence u0 broadcasted
+    f.prob.u0 .+ (t - f.prob.tspan[1]) * f.chain(θ)(adapt(parameterless_type(θ), [t]))
+end
+
+function (f::LogTargetDensity{C, S})(t::Number,
+    θ) where {C <: Lux.AbstractExplicitLayer, S}
+    θ = vector_to_parameters(θ, f.init_params)
+    y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), [t]), θ, f.st)
+    ChainRulesCore.@ignore_derivatives f.st = st
+    f.prob.u0 .+ (t .- f.prob.tspan[1]) .* y
+end
+
+# ODE DU/DX: time derivative of the trial solution, one column per timepoint
+function NNodederi(phi::LogTargetDensity, t::AbstractVector, θ, autodiff::Bool)
+    if autodiff
+        hcat(ForwardDiff.derivative.(ti -> phi(ti, θ), t)...)
+    else
+        # forward difference with step sqrt(machine epsilon)
+        ϵ = sqrt(eps(eltype(t)))
+        (phi(t .+ ϵ, θ) - phi(t, θ)) ./ ϵ
+    end
+end
+
+# physics loglikelihood over problem timespan
+function physloglikelihood(Tar::LogTargetDensity, θ)
+    f = Tar.prob.f
+    p = Tar.prob.p
+    dt = Tar.physdt
+    tspan = Tar.prob.tspan
+    nnlen = length(θ) - Tar.extraparams
+    nnθ = θ[1:nnlen]
+
+    # Timepoints to enforce Physics: the `physdt` grid plus the data timepoints, if any
+    t = collect(eltype(dt), tspan[1]:dt:tspan[2])
+    if !isempty(Tar.dataset[end])
+        t = vcat(t, Tar.dataset[end])
+    end
+
+    # parameter estimation chosen or not
+    if Tar.extraparams > 0
+        # a single estimated parameter reaches `f` as a scalar, several as a vector
+        ode_params = Tar.extraparams == 1 ? θ[nnlen + 1] : θ[(nnlen + 1):end]
+    else
+        ode_params = p == SciMLBase.NullParameters() ? [] : p
+    end
+
+    # trial solution (matrix); the derivative is also trained beyond the dataset but
+    # within the timespan
+    out = Tar(t, nnθ)
+
+    # reject samples case: any non-finite value, not only the first timepoint
+    if any(isinf, out) || any(isinf, ode_params)
+        return -Inf
+    end
+
+    # this is a vector{vector{dx,dy}}(handle case single u(float passed))
+    if size(out, 1) == 1
+        physsol = [f(out[1, i], ode_params, t[i]) for i in axes(out, 2)]
+    else
+        physsol = [f(out[:, i], ode_params, t[i]) for i in axes(out, 2)]
+    end
+    physsol = reduce(hcat, physsol)
+
+    # convert to matrix as nnsol
+    nnsol = NNodederi(Tar, t, nnθ, Tar.autodiff)
+
+    # independent Gaussian per component with std `phystd[i]` at every timepoint
+    return sum(logpdf(MvNormal(nnsol[i, :],
+                LinearAlgebra.Diagonal(fill(abs2(Tar.phystd[i]), size(physsol, 2)))),
+            physsol[i, :]) for i in 1:length(Tar.prob.u0))
+end
+
+# L2 losses loglikelihood(needed mainly for ODE parameter estimation)
+function L2LossData(Tar::LogTargetDensity, θ)
+    # check if dataset is provided
+    if isempty(Tar.dataset[end]) || Tar.extraparams == 0
+        return zero(eltype(θ))
+    end
+
+    # matrix(each row corresponds to vector u's rows)
+    nn = Tar(Tar.dataset[end], θ[1:(length(θ) - Tar.extraparams)])
+
+    # for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
+    return sum(logpdf(MvNormal(nn[i, :],
+                LinearAlgebra.Diagonal(fill(abs2(Tar.l2std[i]), length(Tar.dataset[i])))),
+            Tar.dataset[i]) for i in 1:length(Tar.prob.u0))
+end
+
+# priors for NN parameters + ODE constants
+function priorweights(Tar::LogTargetDensity, θ)
+    allparams = Tar.priors
+    # nn weights
+    nnwparams = allparams[1]
+
+    if Tar.extraparams > 0
+        nnlen = length(θ) - Tar.extraparams
+        # Vector of ode parameters priors
+        invpriors = allparams[2:end]
+
+        # θ[nnlen + j] is the j-th ODE parameter, so it is scored by the j-th prior
+        invlogpdf = sum(logpdf(invpriors[j], θ[nnlen + j]) for j in 1:(Tar.extraparams);
+            init = 0.0)
+
+        return invlogpdf + logpdf(nnwparams, θ[1:nnlen])
+    else
+        return logpdf(nnwparams, θ)
+    end
+end
+
+# sampler for the requested kernel type; anything other than HMC/HMCDA is constructed
+# with NUTS-style arguments
+function kernelchoice(Kernel, max_depth, Δ_max, n_leapfrog, δ, λ)
+    if Kernel == HMC
+        Kernel(n_leapfrog)
+    elseif Kernel == HMCDA
+        Kernel(δ, λ)
+    else
+        Kernel(δ, max_depth = max_depth, Δ_max = Δ_max)
+    end
+end
+
+# integrator for the requested type; jittered/tempered leapfrog take one extra rate
+function integratorchoice(Integrator, initial_ϵ, jitter_rate,
+    tempering_rate)
+    if Integrator == JitteredLeapfrog
+        Integrator(initial_ϵ, jitter_rate)
+    elseif Integrator == TemperedLeapfrog
+        Integrator(initial_ϵ, tempering_rate)
+    else
+        Integrator(initial_ϵ)
+    end
+end
+
+# adaptor for the requested type; `NoAdaptation` is accepted as a type or an instance
+function adaptorchoice(Adaptor, mma, ssa)
+    if Adaptor isa AdvancedHMC.NoAdaptation || Adaptor === AdvancedHMC.NoAdaptation
+        AdvancedHMC.NoAdaptation()
+    else
+        Adaptor(mma, ssa)
+    end
+end
+
+# Run one chain from `θ0` and return the `(samples, stats)` pair from `sample`. All
+# per-chain state is local to this call, so chains running on different threads do not
+# overwrite each other.
+function _sample_bpinn_chain(hamiltonian, metric, θ0, draw_samples; Kernel, Integrator,
+    Adaptor, targetacceptancerate, jitter_rate, tempering_rate, max_depth, Δ_max,
+    n_leapfrog, δ, λ, progress, verbose)
+    initial_ϵ = find_good_stepsize(hamiltonian, θ0)
+    integrator = integratorchoice(Integrator, initial_ϵ, jitter_rate, tempering_rate)
+    adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
+        StepSizeAdaptor(targetacceptancerate, integrator))
+    kernel = AdvancedHMC.make_kernel(kernelchoice(Kernel, max_depth, Δ_max, n_leapfrog,
+            δ, λ), integrator)
+    return sample(hamiltonian, kernel, θ0, draw_samples, adaptor;
+        progress = progress, verbose = verbose)
+end
+
+"""
+    ahmc_bayesian_pinn_ode(prob, chain;
+        dataset = [[]], init_params = nothing,
+        draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
+        phystd = [0.05], priorsNNw = (0.0, 2.0),
+        param = [], nchains = 1, autodiff = false, Kernel = HMC,
+        Integrator = Leapfrog, Adaptor = StanHMCAdaptor,
+        targetacceptancerate = 0.8, Metric = DiagEuclideanMetric,
+        jitter_rate = 3.0, tempering_rate = 3.0, max_depth = 10,
+        Δ_max = 1000, n_leapfrog = 10, δ = 0.65, λ = 0.3,
+        progress = false, verbose = false)
+
+Sample the posterior of a Bayesian PINN for the ODE `prob` with AdvancedHMC and, when
+`param` is given, the posterior of the ODE parameters as well.
+
+## Example
+
+```julia
+linear = (u, p, t) -> -u / p[1] + exp(t / p[2]) * cos(t)
+tspan = (0.0, 10.0)
+u0 = 0.0
+p = [5.0, -5.0]
+prob = ODEProblem(linear, u0, tspan, p)
+
+# CREATE DATASET (Necessity for accurate Parameter estimation)
+sol = solve(prob, Tsit5(); saveat = 0.05)
+u = sol.u[1:100]
+time = sol.t[1:100]
+
+# dataset and BPINN create
+x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u)))
+dataset = [x̂, time]
+
+chainflux1 = Flux.Chain(Flux.Dense(1, 5, tanh), Flux.Dense(5, 5, tanh), Flux.Dense(5, 1))
+
+# simply solving ode here (uses ode params specified in prob); the dataset is optional
+fh_mcmc_chainflux1, fhsamplesflux1, fhstatsflux1 = ahmc_bayesian_pinn_ode(prob, chainflux1,
+    dataset = dataset,
+    draw_samples = 1500,
+    l2std = [0.05],
+    phystd = [0.05],
+    priorsNNw = (0.0, 3.0))
+
+# solving ode + estimating parameters hence dataset needed to optimize parameters upon
+# + Prior Distributions for ODE params
+fh_mcmc_chainflux2, fhsamplesflux2, fhstatsflux2 = ahmc_bayesian_pinn_ode(prob, chainflux1,
+    dataset = dataset,
+    draw_samples = 1500,
+    l2std = [0.05],
+    phystd = [0.05],
+    priorsNNw = (0.0, 3.0),
+    param = [Normal(6.5, 0.5), Normal(-3, 0.5)])
+```
+
+## NOTES
+
+Dataset is required for accurate Parameter estimation + solving equations.
+In case you are only solving the Equations for solution, do not provide dataset.
+
+## Positional Arguments
+
+- `prob`: DEProblem (out of place; the function signature should be `f(u, p, t)`).
+- `chain`: Lux/Flux Neural Network which would be made the Bayesian PINN.
+
+## Keyword Arguments
+
+- `dataset`: Vector containing Vectors of corresponding u, t values.
+- `init_params`: initial parameter values for BPINN (ideally for multiple chains
+  different initializations preferred).
+- `nchains`: number of chains you want to sample (random initialisation of params by
+  default); must not exceed `Threads.nthreads()`.
+- `draw_samples`: number of samples to be drawn in the MCMC algorithms. The number of
+  adaptation steps is not set here, so `AdvancedHMC.sample` uses its own default.
+- `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset.
+- `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System.
+- `priorsNNw`: `(mean, std)` for BPINN parameters. Weights and Biases of BPINN are
+  Normal Distributions by default.
+- `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems,
+  in the same order as the parameters in `prob.p`.
+- `autodiff`: Boolean Value for choice of Derivative Backend (default is numerical).
+- `physdt`: Timestep for approximating ODE in its Time domain (1/20.0 by default).
+
+AHMC still developing convenience structs so might need changes on new releases.
+
+- `Kernel`: Choice of MCMC Sampling Algorithm (AdvancedHMC.jl implementations
+  HMC/NUTS/HMCDA).
+- `targetacceptancerate`: Target percentage (in decimal) of iterations in which the
+  proposals were accepted (0.8 by default).
+- `Integrator` (`jitter_rate`, `tempering_rate`), `Metric`, `Adaptor`:
+  https://turinglang.org/AdvancedHMC.jl/stable/
+- `max_depth`: Maximum doubling tree depth (NUTS).
+- `Δ_max`: Maximum divergence during doubling tree (NUTS).
+- `n_leapfrog`: number of leapfrog steps for HMC.
+- `δ`: target acceptance probability for NUTS/HMCDA.
+- `λ`: target trajectory length for HMCDA.
+- `progress`: controls whether to show the progress meter or not.
+- `verbose`: controls the verbosity. (Sample call args in AHMC)
+
+## Returns
+
+`(chain, samples, stats)`: an `MCMCChains.Chains`, the raw samples and the sampler
+statistics. For `nchains > 1` each of the three is a `Vector` with one entry per chain.
+"""
+function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain;
+    dataset = [[]],
+    init_params = nothing, draw_samples = 1000,
+    physdt = 1 / 20.0, l2std = [0.05],
+    phystd = [0.05], priorsNNw = (0.0, 2.0),
+    param = [], nchains = 1,
+    autodiff = false,
+    Kernel = HMC, Integrator = Leapfrog,
+    Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
+    Metric = DiagEuclideanMetric, jitter_rate = 3.0,
+    tempering_rate = 3.0, max_depth = 10, Δ_max = 1000,
+    n_leapfrog = 10, δ = 0.65, λ = 0.3, progress = false,
+    verbose = false)
+    # dataset would be (x̂,t)
+    # priors: pdf for W,b + pdf for ODE params
+    if isinplace(prob)
+        error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t).")
+    end
+
+    # `[]` and the default `[[]]` both mean "no measurements"
+    hasdata = !isempty(dataset) && !all(isempty, dataset)
+    if hasdata
+        if length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}})
+            error("Invalid dataset. dataset would be timeseries (x̂,t) where type: Vector{Vector{AbstractFloat}}")
+        end
+        data = dataset
+    else
+        # typed empty placeholder: fits LogTargetDensity's `dataset` type bound and
+        # keeps `isempty(dataset[end])` true in the likelihood terms
+        data = [Float64[]]
+    end
+
+    if hasdata && isempty(param)
+        println("Dataset is only needed for Parameter Estimation + Forward Problem, not in only Forward Problem case.")
+    elseif !hasdata && !isempty(param)
+        error("Dataset Required for Parameter Estimation.")
+    end
+
+    if chain isa Lux.AbstractExplicitLayer || chain isa Flux.Chain
+        # Flux-vector, Lux-Named Tuple
+        initial_nnθ, recon, st = generate_Tar(chain, init_params)
+    else
+        error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported")
+    end
+
+    if nchains > Threads.nthreads()
+        error("number of chains is greater than available threads")
+    elseif nchains < 1
+        error("number of chains must be at least 1")
+    end
+
+    # Float64 because find_good_stepsize/PhasePoint take only Float64, whatever `physdt` is
+    if chain isa Lux.AbstractExplicitLayer
+        # Lux chain(using component array later as vector_to_parameter need namedtuple)
+        initial_θ = collect(Float64,
+            vcat(ComponentArrays.ComponentArray(initial_nnθ)))
+    else
+        initial_θ = collect(Float64, initial_nnθ)
+    end
+
+    # adding ode parameter estimation
+    nnlen = length(initial_θ)
+    ninv = length(param)
+    priors = [
+        MvNormal(priorsNNw[1] * ones(nnlen),
+            LinearAlgebra.Diagonal(map(abs2, priorsNNw[2] .* ones(nnlen)))),
+    ]
+
+    # append Ode params to all paramvector
+    if ninv > 0
+        # each ODE parameter starts at the first parameter of its prior (a Normal's mean)
+        initial_θ = vcat(initial_θ, [Distributions.params(param[i])[1] for i in 1:ninv])
+        priors = vcat(priors, param)
+    end
+    nparameters = nnlen + ninv
+
+    t0 = prob.tspan[1]
+    # dimensions would be total no of params,initial_nnθ for Lux namedTuples
+    ℓπ = LogTargetDensity(nparameters, prob, recon, st, data, priors,
+        phystd, l2std, autodiff, physdt, ninv,
+        initial_nnθ)
+
+    # fail early if the network output and u0 cannot be broadcast together
+    try
+        ℓπ(t0, initial_θ[1:nnlen])
+    catch err
+        if isa(err, DimensionMismatch)
+            throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
+        else
+            rethrow()
+        end
+    end
+
+    # Define Hamiltonian system (nparameters ~ dimensionality of the sampling space)
+    metric = Metric(nparameters)
+    hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)
+
+    runchain(θ0) = _sample_bpinn_chain(hamiltonian, metric, θ0, draw_samples;
+        Kernel = Kernel, Integrator = Integrator, Adaptor = Adaptor,
+        targetacceptancerate = targetacceptancerate, jitter_rate = jitter_rate,
+        tempering_rate = tempering_rate, max_depth = max_depth, Δ_max = Δ_max,
+        n_leapfrog = n_leapfrog, δ = δ, λ = λ, progress = progress, verbose = verbose)
+
+    # parallel sampling option
+    if nchains != 1
+        # Cache to store the chains
+        chains = Vector{Any}(undef, nchains)
+        statsc = Vector{Any}(undef, nchains)
+        samplesc = Vector{Any}(undef, nchains)
+        # every chain starts its ODE parameters from the same prior-based values
+        odeθ = initial_θ[(nnlen + 1):end]
+
+        Threads.@threads for i in 1:nchains
+            # each chain has different initial NNparameter values(better posterior exploration)
+            chainsamples, chainstats = runchain(vcat(randn(nnlen), odeθ))
+            samplesc[i] = chainsamples
+            statsc[i] = chainstats
+            chains[i] = MCMCChains.Chains(reduce(hcat, chainsamples)')
+        end
+
+        return chains, samplesc, statsc
+    else
+        samples, stats = runchain(initial_θ)
+
+        # return a chain(basic chain),samples and stats
+        mcmc_chain = MCMCChains.Chains(reduce(hcat, samples)')
+        return mcmc_chain, samples, stats
+    end
+end
\ No newline at end of file
diff --git a/test/BPINN_Tests.jl b/test/BPINN_Tests.jl
new file mode 100644
index 0000000000..3fabbdfc3b
--- /dev/null
+++ b/test/BPINN_Tests.jl
@@ -0,0 +1,278 @@
+# Testing Code
+using Test, MCMCChains
+using ForwardDiff, Distributions, OrdinaryDiffEq
+using NeuralPDE, Flux, OptimizationOptimisers, AdvancedHMC, Lux
+using Statistics, Random, Functors, ComponentArrays
+
+Random.seed!(100)
+
+# for sampled params->lux ComponentArray
+function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
+    @assert length(ps_new) == Lux.parameterlength(ps)
+    i = 1
+    function get_ps(x)
+        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
+        i += length(x)
+        return z
+    end
+    return Functors.fmap(get_ps, ps)
+end
+
+## PROBLEM-1 (WITHOUT PARAMETER ESTIMATION)
+linear_analytic = (u0, p, t) -> u0 + sin(2 * π * t) / (2 * π)
+linear = (u, p, t) -> cos(2 * π * t)
+tspan = (0.0, 2.0)
+u0 = 0.0
+prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan)
+
+# Numerical and Analytical Solutions
+ta = range(tspan[1], tspan[2], length = 300)
+u = [linear_analytic(u0, nothing, ti) for ti in ta]
+sol1 = solve(prob, Tsit5())
+
+# BPINN AND TRAINING DATASET CREATION, NN create, Reconstruct
+x̂ = collect(Float64, Array(u) + 0.02 * randn(size(u)))
+time = vec(collect(Float64, ta))
+dataset = [x̂[1:100], time[1:100]]
+
+# Call BPINN, create chain
+chainflux = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> f64
+chainlux = Lux.Chain(Lux.Dense(1, 7, tanh), Lux.Dense(7, 1))
+
+fh_mcmc_chain1, fhsamples1, fhstats1 = ahmc_bayesian_pinn_ode(prob, chainflux,
+    dataset = dataset,
+    draw_samples = 2500,
+    n_leapfrog = 30)
+
+fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chainlux,
+    dataset = dataset,
+    draw_samples = 2500,
+    n_leapfrog = 30)
+
+init1, re1 = destructure(chainflux)
+θinit, st = 
Lux.setup(Random.default_rng(), chainlux)
+
+# TESTING TIMEPOINTS TO PLOT ON, analytic solution and its right-hand side at those points
+t = time
+p = prob.p
+physsol1 = [linear_analytic(prob.u0, p, t[i]) for i in eachindex(t)]
+physsol2 = [linear(physsol1[i], p, t[i]) for i in eachindex(t)]
+
+# Mean over the last ~500 sampled parameter sets' curves (flux and lux chains) [Ensemble predictions]
+out = re1.(fhsamples1[(end - 500):end])
+yu = collect(out[i](t') for i in eachindex(out))
+fluxmean = [mean(vcat(yu...)[:, i]) for i in eachindex(t)]
+meanscurve1 = prob.u0 .+ (t .- prob.tspan[1]) .* fluxmean
+
+θ = [vector_to_parameters(fhsamples2[i], θinit) for i in 2000:2500]
+luxar = [chainlux(t', θ[i], st)[1] for i in 1:500]
+luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
+meanscurve2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean
+
+@test mean(abs.(x̂ .- meanscurve1)) < 0.05
+@test mean(abs.(physsol1 .- meanscurve1)) < 0.005
+@test mean(abs.(x̂ .- meanscurve2)) < 0.05
+@test mean(abs.(physsol1 .- meanscurve2)) < 0.005
+
+println("now parameter estimation problem 1")
+## PROBLEM-1 (WITH PARAMETER ESTIMATION)
+linear_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
+linear = (u, p, t) -> cos(p * t)
+tspan = (0.0, 2.0)
+u0 = 0.0
+p = 2 * pi
+prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan, p)
+
+# Numerical and Analytical Solutions (u and time are overwritten by the analytic grid below)
+sol1 = solve(prob, Tsit5(); saveat = 0.01)
+u = sol1.u
+time = sol1.t
+
+# BPINN AND TRAINING DATASET CREATION (noisy analytic solution, first 50 points used as data)
+ta = range(tspan[1], tspan[2], length = 200)
+u = [linear_analytic(u0, p, ti) for ti in ta]
+x̂ = collect(Float64, Array(u) + 0.02 * randn(size(u)))
+time = vec(collect(Float64, ta))
+dataset = [x̂[1:50], time[1:50]]
+
+# comparing how diff NNs capture non-linearity
+chainflux1 = Flux.Chain(Flux.Dense(1, 7, tanh), Flux.Dense(7, 1)) |> f64
+chainlux1 = Lux.Chain(Lux.Dense(1, 7, tanh), Lux.Dense(7, 1))
+
+fh_mcmc_chain1, fhsamples1, fhstats1 = ahmc_bayesian_pinn_ode(prob, chainflux1,
+    dataset = dataset,
+    draw_samples = 2500,
+    physdt = 1 / 50.0f0,
+    priorsNNw = (0.0, 3.0),
+    param = [LogNormal(9, 0.5)],
+    Metric = DiagEuclideanMetric,
+    n_leapfrog = 30)
+
+fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chainlux1,
+    dataset = dataset,
+    draw_samples = 2500,
+    physdt = 1 / 50.0f0,
+    priorsNNw = (0.0, 3.0),
+    param = [LogNormal(9, 0.5)],
+    Metric = DiagEuclideanMetric,
+    n_leapfrog = 30)
+
+init1, re1 = destructure(chainflux1)
+θinit, st = Lux.setup(Random.default_rng(), chainlux1)
+
+# testing timepoints, analytic solution and its right-hand side at those points
+t = time
+p = prob.p
+physsol1 = [linear_analytic(prob.u0, p, t[i]) for i in eachindex(t)]
+physsol2 = [linear(physsol1[i], p, t[i]) for i in eachindex(t)]
+
+# Mean over the last ~500 sampled parameter sets' curves (flux and lux chains); entries 1:22 are the NN parameters
+out = re1.([fhsamples1[i][1:22] for i in 2000:2500])
+yu = collect(out[i](t') for i in eachindex(out))
+fluxmean = [mean(vcat(yu...)[:, i]) for i in eachindex(t)]
+meanscurve1 = prob.u0 .+ (t .- prob.tspan[1]) .* fluxmean
+
+θ = [vector_to_parameters(fhsamples2[i][1:(end - 1)], θinit) for i in 2000:2500]
+luxar = [chainlux1(t', θ[i], st)[1] for i in 1:500]
+luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
+meanscurve2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean
+
+@test mean(abs.(x̂ .- meanscurve1)) < 5e-1
+@test mean(abs.(physsol1 .- meanscurve1)) < 5e-1
+@test mean(abs.(x̂ .- meanscurve2)) < 5e-2
+@test mean(abs.(physsol1 .- meanscurve2)) < 5e-2
+
+# ESTIMATED ODE PARAMETERS (NN1 AND NN2): entry 23 follows the 22 NN parameters
+@test abs(p - mean([fhsamples2[i][23] for i in 2000:2500])) < abs(0.2 * p)
+@test abs(p - mean([fhsamples1[i][23] for i in 2000:2500])) < abs(0.2 * p)
+
+## PROBLEM-2
+linear = (u, p, t) -> -u / p[1] + exp(t / p[2]) * cos(t)
+tspan = (0.0, 10.0)
+u0 = 0.0
+p = [5.0, -5.0]
+prob = ODEProblem(linear, u0, tspan, p)
+
+# PROBLEM-2 analytic solution (constants 5 and -5 from p are hard-coded)
+linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t))
+
+# PLOT SOLUTION AND CREATE DATASET (first 100 saved points of the numerical solution)
+sol = solve(prob, Tsit5(); saveat = 0.05)
+u = sol.u[1:100]
+time = sol.t[1:100]
+
+# dataset and BPINN create
+x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u)))
+dataset = [x̂, time]
+
+chainflux12 = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 6, tanh),
+    Flux.Dense(6, 1)) |> f64
+chainlux12 = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
+
+fh_mcmc_chainflux12, fhsamplesflux12, fhstatsflux12 = ahmc_bayesian_pinn_ode(prob,
+    chainflux12,
+    dataset = dataset,
+    draw_samples = 2000,
+    l2std = [0.05],
+    phystd = [
+        0.05,
+    ],
+    priorsNNw = (0.0,
+        3.0),
+    n_leapfrog = 30)
+
+fh_mcmc_chainflux22, fhsamplesflux22, fhstatsflux22 = ahmc_bayesian_pinn_ode(prob,
+    chainflux12,
+    dataset = dataset,
+    draw_samples = 2000,
+    l2std = [0.05],
+    phystd = [
+        0.05,
+    ],
+    priorsNNw = (0.0,
+        3.0),
+    param = [
+        Normal(6.5,
+            0.5),
+        Normal(-3,
+            0.5),
+    ],
+    n_leapfrog = 30)
+
+fh_mcmc_chainlux12, fhsampleslux12, fhstatslux12 = ahmc_bayesian_pinn_ode(prob, chainlux12,
+    dataset = dataset,
+    draw_samples = 2000,
+    l2std = [0.05],
+    phystd = [0.05],
+    priorsNNw = (0.0,
+        3.0),
+    n_leapfrog = 30)
+
+fh_mcmc_chainlux22, fhsampleslux22, fhstatslux22 = ahmc_bayesian_pinn_ode(prob, chainlux12,
+    dataset = dataset,
+    draw_samples = 2000,
+    l2std = [0.05],
+    phystd = [0.05],
+    priorsNNw = (0.0,
+        3.0),
+    param = [
+        Normal(6.5,
+            0.5),
+        Normal(-3,
+            0.5),
+    ],
+    n_leapfrog = 30)
+
+init1, re1 = destructure(chainflux12)
+θinit, st = Lux.setup(Random.default_rng(), chainlux12)
+
+# testing timepoints (the full solution grid) and analytic reference solution
+t = sol.t
+p = prob.p
+physsol1 = [linear_analytic(prob.u0, p, t[i]) for i in eachindex(t)]
+
+# Mean over the last ~500 sampled parameter sets' curves (flux chains); entries 1:61 are the NN parameters
+out = re1.([fhsamplesflux12[i][1:61] for i in 1500:2000])
+yu = [out[i](t') for i in eachindex(out)]
+fluxmean = [mean(vcat(yu...)[:, i]) for i in eachindex(t)]
+meanscurve1_1 = prob.u0 .+ (t .- prob.tspan[1]) .* fluxmean
+
+@test mean(abs.(sol.u .- meanscurve1_1)) < 1e-2
+@test mean(abs.(physsol1 .- meanscurve1_1)) < 1e-2
+
+out = re1.([fhsamplesflux22[i][1:61] for i in 1500:2000])
+yu = [out[i](t') for i in eachindex(out)]
+fluxmean = [mean(vcat(yu...)[:, i]) for i in eachindex(t)]
+meanscurve1_2 = prob.u0 .+ (t .- prob.tspan[1]) .* fluxmean
+
+@test mean(abs.(sol.u .- meanscurve1_2)) < 5e-2
+@test mean(abs.(physsol1 .- meanscurve1_2)) < 5e-2
+
+# estimated parameters(flux chain): entries 62 and 63 follow the 61 NN parameters
+param1 = mean(i[62] for i in fhsamplesflux22[1500:2000])
+param2 = mean(i[63] for i in fhsamplesflux22[1500:2000])
+@test abs(param1 - p[1]) < abs(0.3 * p[1])
+@test abs(param2 - p[2]) < abs(0.3 * p[2])
+
+# Mean over the last ~500 sampled parameter sets' curves (lux chains) [Ensemble predictions]
+θ = [vector_to_parameters(fhsampleslux12[i], θinit) for i in 1500:2000]
+luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500]
+luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
+meanscurve2_1 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean
+
+@test mean(abs.(sol.u .- meanscurve2_1)) < 1e-2
+@test mean(abs.(physsol1 .- meanscurve2_1)) < 1e-2
+
+θ = [vector_to_parameters(fhsampleslux22[i][1:(end - 2)], θinit) for i in 1500:2000]
+luxar = [chainlux12(t', θ[i], st)[1] for i in 1:500]
+luxmean = [mean(vcat(luxar...)[:, i]) for i in eachindex(t)]
+meanscurve2_2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean
+
+@test mean(abs.(sol.u .- meanscurve2_2)) < 5e-2
+@test mean(abs.(physsol1 .- meanscurve2_2)) < 5e-2
+
+# estimated parameters(lux chain): entries 62 and 63 follow the 61 NN parameters
+param1 = mean(i[62] for i in fhsampleslux22[1500:2000])
+param2 = mean(i[63] for i in fhsampleslux22[1500:2000])
+@test abs(param1 - p[1]) < abs(0.3 * p[1])
+@test abs(param2 - p[2]) < abs(0.3 * p[2])
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 75bf98a91f..2a176cf870 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,12 +15,18 @@ function dev_subpkg(subpkg)
 end
 
 @time begin
+    # Bayesian PINN ODE solver tests (fixes 682)
+    if GROUP == "All" || GROUP == "ODEBPINN"
+        @time @safetestset "Bpinn ODE solver" begin include("BPINN_Tests.jl") end
+    end
+
     if GROUP == "All" || GROUP == "NNPDE1"
         @time @safetestset "NNPDE" begin include("NNPDE_tests.jl") end
     end
     if GROUP == "All" || GROUP == "NNODE"
         @time @safetestset "NNODE" begin include("NNODE_tests.jl") end
     end
+
     if GROUP == "All" || GROUP == "NNPDE2"
         @time @safetestset "Additional Loss" begin include("additional_loss_tests.jl") end
         @time @safetestset "Direction Function Approximation" begin include("direct_function_tests.jl") end