Skip to content

Commit

Permalink
parameter estimation does not work
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Jul 17, 2023
1 parent dfed5ee commit f77d04a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
41 changes: 20 additions & 21 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ function physloglikelihood(Tar::LogTargetDensity, θ)
p = Tar.prob.p
f = Tar.prob.f

Check warning on line 105 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L103-L105

Added lines #L103 - L105 were not covered by tests

allparams = Tar.priors
invparams = allparams[2:length(allparams)]
meaninv = [invparam[1] for invparam in invparams]

Check warning on line 109 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L107-L109

Added lines #L107 - L109 were not covered by tests

autodiff = Tar.autodiff
dt = Tar.physdt
t = collect(Float64, Tar.prob.tspan[1]:dt:Tar.prob.tspan[2])

Check warning on line 113 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L111-L113

Added lines #L111 - L113 were not covered by tests
Expand All @@ -113,13 +117,17 @@ function physloglikelihood(Tar::LogTargetDensity, θ)

# # this is a vector{vector{dx,dy}}(handle case single u(float passed))
if length(out[:, 1]) == 1

Check warning on line 119 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L119

Added line #L119 was not covered by tests
# shifted by prior mean
ode_params = exp.(θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + log.(meaninv))
physsol = [f(out[:, i][1],

Check warning on line 122 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L121-L122

Added lines #L121 - L122 were not covered by tests
θ[((length(θ) - Tar.extraparams) + 1):length(θ)],
ode_params,
t[i])
for i in 1:length(out[1, :])]
else
# shifted by prior mean
ode_params = exp.(θ[((length(θ) - Tar.extraparams) + 1):length(θ)] + log.(meaninv))
physsol = [f(out[:, i],

Check warning on line 129 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
θ[((length(θ) - Tar.extraparams) + 1):length(θ)],
ode_params,
t[i])
for i in 1:length(out[1, :])]
end
Expand All @@ -132,9 +140,7 @@ function physloglikelihood(Tar::LogTargetDensity, θ)
n = length(out[1, :])
for i in 1:length(Tar.prob.u0)

Check warning on line 141 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L139-L141

Added lines #L139 - L141 were not covered by tests
# can add phystd[i] for u[i]
physlogprob += logpdf(MvNormal(nnsol[i, :],
Diagonal(Tar.phystd[i]^2 .* ones(n))),
physsol[i, :])
physlogprob += logpdf(MvNormal(nnsol[i, :], Tar.phystd[i]), physsol[i, :])
end
return physlogprob

Check warning on line 145 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L143-L145

Added lines #L143 - L145 were not covered by tests
end
Expand All @@ -149,36 +155,29 @@ function L2LossData(Tar::LogTargetDensity, θ)
for i in 1:length(Tar.prob.u0)

Check warning on line 155 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L153-L155

Added lines #L153 - L155 were not covered by tests
# can add l2std[i] for u[i]
# for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
L2logprob += logpdf(MvNormal(nn[i, :], Diagonal(Tar.l2std[i]^2 .* ones(n))),
Tar.dataset[i])
L2logprob += logpdf(MvNormal(nn[i, :], Tar.l2std[i]), Tar.dataset[i])
end
return L2logprob

Check warning on line 160 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L158-L160

Added lines #L158 - L160 were not covered by tests
end

# priors for NN parameters + ODE constants
function priorweights(Tar::LogTargetDensity, θ)
allparams = Tar.priors

Check warning on line 165 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L164-L165

Added lines #L164 - L165 were not covered by tests
# ode parameters
invparams = allparams[2:length(allparams)]
meaninv = [invparam[1] for invparam in invparams]
stdinv = [invparam[2] for invparam in invparams]

Check warning on line 168 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L167-L168

Added lines #L167 - L168 were not covered by tests
# nn weights
nnwparams = allparams[1]
varw = nnwparams[2]^2
stdw = nnwparams[2]
prisw = nnwparams[1] .* ones(length(θ) - Tar.extraparams)

Check warning on line 172 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L170-L172

Added lines #L170 - L172 were not covered by tests

if Tar.extraparams > 0
# ode parameters
invparams = allparams[2:length(allparams)]
varinv = [invparam[2]^2 for invparam in invparams]

return (logpdf(MvNormal(θ[((length(θ) - Tar.extraparams) + 1):length(θ)],
Diagonal(varinv)), meaninv)
return (logpdf(MvNormal(zeros(Tar.extraparams), stdinv),

Check warning on line 175 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L174-L175

Added lines #L174 - L175 were not covered by tests
θ[((length(θ) - Tar.extraparams) + 1):length(θ)])
+
logpdf(MvNormal(θ[1:(length(θ) - Tar.extraparams)],
Diagonal(varw .*
ones(length(θ[1:(length(θ) - Tar.extraparams)])))),
prisw))
logpdf(MvNormal(prisw, stdw), θ[1:(length(θ) - Tar.extraparams)]))
else
return logpdf(MvNormal(θ, Diagonal(varw .* ones(length(θ)))), prisw)
return logpdf(MvNormal(prisw, stdw), θ)

Check warning on line 180 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L180

Added line #L180 was not covered by tests
end
end

Expand Down Expand Up @@ -240,7 +239,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.DEProblem, chain::Flux.Chain,

# [i[1] for i in param]
if length(param) > 0
append!(initial_θ, randn(length(param)))
append!(initial_θ, zeros(length(param)))
append!(priors, param)

Check warning on line 243 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L241-L243

Added lines #L241 - L243 were not covered by tests
end
nparameters = length(initial_θ)

Check warning on line 245 in src/advancedHMC_MCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/advancedHMC_MCMC.jl#L245

Added line #L245 was not covered by tests
Expand Down
9 changes: 8 additions & 1 deletion test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,11 @@ fh_mcmc_chain, fhsamples, fhstats = ahmc_bayesian_pinn_ode(prob, chainfh, datase
autodiff = true, l2std = [0.05],
phystd = [0.05],
priorsNNw = (0.0, 3.0),
param = [(2.3, 0.5), (4.3, 0.5)])
param = [(2.3, 0.5), (4.3, 0.5)])
fhsamples1[1000]
fhsamples[1000]
param = [(1.5, 0.5), (1.2, 0.5), (3.3, 0.5), (1.4, 0.5)]
param = [(2.3, 0.5), (4.3, 0.5)]
yuhj = log.([i[1] for i in param])
exp.(fhsamples1[1000][17:18] + yuhj)
exp.(fhsamples[1000][17:18] + yuhj)

0 comments on commit f77d04a

Please sign in to comment.