diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl index 18452df0c..98a8d5db2 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/advancedHMC_MCMC.jl @@ -104,6 +104,10 @@ function physloglikelihood(Tar::LogTargetDensity, θ) p = Tar.prob.p f = Tar.prob.f + allparams = Tar.priors + invparams = allparams[2:length(allparams)] + meaninv = [invparam[1] for invparam in invparams] + autodiff = Tar.autodiff dt = Tar.physdt t = collect(Float64, Tar.prob.tspan[1]:dt:Tar.prob.tspan[2]) @@ -113,13 +117,17 @@ function physloglikelihood(Tar::LogTargetDensity, θ) # # this is a vector{vector{dx,dy}}(handle case single u(float passed)) if length(out[:, 1]) == 1 + # shifted by prior mean + ode_params = exp.(θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + log.(meaninv)) physsol = [f(out[:, i][1], - θ[((length(θ) - Tar.extraparams) + 1):length(θ)], + ode_params, t[i]) for i in 1:length(out[1, :])] else + # shifted by prior mean + ode_params = exp.(θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + log.(meaninv)) physsol = [f(out[:, i], - θ[((length(θ) - Tar.extraparams) + 1):length(θ)], + ode_params, t[i]) for i in 1:length(out[1, :])] end @@ -132,9 +140,7 @@ function physloglikelihood(Tar::LogTargetDensity, θ) n = length(out[1, :]) for i in 1:length(Tar.prob.u0) # can add phystd[i] for u[i] - physlogprob += logpdf(MvNormal(nnsol[i, :], - Diagonal(Tar.phystd[i]^2 .* ones(n))), - physsol[i, :]) + physlogprob += logpdf(MvNormal(nnsol[i, :], Tar.phystd[i]), physsol[i, :]) end return physlogprob end @@ -149,8 +155,7 @@ function L2LossData(Tar::LogTargetDensity, θ) for i in 1:length(Tar.prob.u0) # can add l2std[i] for u[i] # for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra - L2logprob += logpdf(MvNormal(nn[i, :], Diagonal(Tar.l2std[i]^2 .* ones(n))), - Tar.dataset[i]) + L2logprob += logpdf(MvNormal(nn[i, :], Tar.l2std[i]), Tar.dataset[i]) end return L2logprob end @@ -158,27 +163,21 @@ end # priors for NN parameters + ODE constants function priorweights(Tar::LogTargetDensity, θ) allparams = Tar.priors + # ode parameters invparams = allparams[2:length(allparams)] - meaninv = [invparam[1] for invparam in invparams] + stdinv = [invparam[2] for invparam in invparams] # nn weights nnwparams = allparams[1] - varw = nnwparams[2]^2 + stdw = nnwparams[2] prisw = nnwparams[1] .* ones(length(θ) - Tar.extraparams) if Tar.extraparams > 0 - # ode parameters - invparams = allparams[2:length(allparams)] - varinv = [invparam[2]^2 for invparam in invparams] - - return (logpdf(MvNormal(θ[((length(θ) - Tar.extraparams) + 1):length(θ)], - Diagonal(varinv)), meaninv) + return (logpdf(MvNormal(zeros(Tar.extraparams), stdinv), + θ[((length(θ) - Tar.extraparams) + 1):length(θ)]) + - logpdf(MvNormal(θ[1:(length(θ) - Tar.extraparams)], - Diagonal(varw .* - ones(length(θ[1:(length(θ) - Tar.extraparams)])))), - prisw)) + logpdf(MvNormal(prisw, stdw), θ[1:(length(θ) - Tar.extraparams)])) else - return logpdf(MvNormal(θ, Diagonal(varw .* ones(length(θ)))), prisw) + return logpdf(MvNormal(prisw, stdw), θ) end end @@ -240,7 +239,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain, # [i[1] for i in param] if length(param) > 0 - append!(initial_θ, randn(length(param))) + append!(initial_θ, zeros(length(param))) append!(priors, param) end nparameters = length(initial_θ) diff --git a/test/BPINN_Tests.jl b/test/BPINN_Tests.jl index 240ef411a..ccfd7ceb0 100644 --- a/test/BPINN_Tests.jl +++ b/test/BPINN_Tests.jl @@ -211,4 +211,11 @@ fh_mcmc_chain, fhsamples, fhstats = ahmc_bayesian_pinn_ode(prob, chainfh, datase autodiff = true, l2std = [0.05], phystd = [0.05], priorsNNw = (0.0, 3.0), - param = [(2.3, 0.5), (4.3, 0.5)]) \ No newline at end of file + param = [(2.3, 0.5), (4.3, 0.5)]) +fhsamples1[1000] +fhsamples[1000] +param = [(1.5, 0.5), (1.2, 0.5), (3.3, 0.5), (1.4, 0.5)] +param = [(2.3, 0.5), (4.3, 0.5)] +yuhj = log.([i[1] for i in param]) +exp.(fhsamples1[1000][17:18] + yuhj) +exp.(fhsamples[1000][17:18] + yuhj) \ No newline at end of file