Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Aug 1, 2023
1 parent 9c07c97 commit 2c70494
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions test/BPINN_Tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
using DifferentialEquations, MCMCChains, ForwardDiff, Distributions
using NeuralPDE, Flux, OptimizationOptimisers, AdvancedHMC, Lux
using StatProfilerHTML, Profile, Statistics, Random, Functors, ComponentArrays
using BenchmarkTools, Plots, StatsPlots, Test
plotly()
Profile.init()
using BenchmarkTools, Test

# for sampled params->lux ComponentArray
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
Expand Down Expand Up @@ -155,12 +153,11 @@ linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t))
sol = solve(prob, Tsit5(); saveat = 0.05)
u = sol.u[1:100]
time = sol.t[1:100]
plot(sol1.t, sol1.u)

# dataset and BPINN create
= collect(Float64, Array(u) + 0.05 * randn(size(u)))
dataset = [x̂, time]
plot!(time, x̂)

chainflux1 = Flux.Chain(Flux.Dense(1, 5, tanh), Flux.Dense(5, 5, tanh), Flux.Dense(5, 1))
chainlux1 = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 1))

Expand Down Expand Up @@ -212,7 +209,6 @@ init1, re1 = destructure(chainflux1)
t = sol.t
p = prob.p
physsol1 = [linear_analytic(prob.u0, p, t[i]) for i in eachindex(t)]
plot!(t, physsol1)

# Mean of last 500 sampled parameter's curves(flux chains)[Ensemble predictions]
out = re1.([fhsamplesflux1[i][1:22] for i in 500:1000])
Expand Down Expand Up @@ -277,7 +273,6 @@ prob = ODEProblem(lotka_volterra, u0, tspan, p)
solution = solve(prob, Tsit5(); saveat = 0.05)

# Plot simulation.
plot(solution)
time = solution.t
u = hcat(solution.u...)
# BPINN AND TRAINING DATASET CREATION, NN create, Reconstruct
Expand Down Expand Up @@ -378,7 +373,6 @@ function getensemble(yu, num_models)
end
fluxmean = getensemble(yu, length(out))
meanscurve1_1 = prob.u0 .+ (t' .- prob.tspan[1]) .* fluxmean
plot!(t, meanscurve1_1')
mean(abs.(u .- meanscurve1_1))

@test mean(abs2.(x̂ .- meanscurve1_1)) < 2e-2
Expand All @@ -390,7 +384,6 @@ out = re1.([fhsamplesflux2[i][1:68] for i in 500:1000])
yu = collect(out[i](t') for i in eachindex(out))
fluxmean = getensemble(yu, length(out))
meanscurve1_2 = prob.u0 .+ (t' .- prob.tspan[1]) .* fluxmean
plot!(t, meanscurve1_2')
mean(abs.(u .- meanscurve1_2))

@test mean(abs2.(x̂ .- meanscurve1)) < 2e-2
Expand Down

0 comments on commit 2c70494

Please sign in to comment.