From 2c704947b90a8b930376f21f7ead6a6ce6b4be95 Mon Sep 17 00:00:00 2001 From: Astitva Aggarwal Date: Wed, 2 Aug 2023 01:28:58 +0530 Subject: [PATCH] minor changes --- test/BPINN_Tests.jl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/test/BPINN_Tests.jl b/test/BPINN_Tests.jl index d1f5cb8ae..15a7adbdc 100644 --- a/test/BPINN_Tests.jl +++ b/test/BPINN_Tests.jl @@ -2,9 +2,7 @@ using DifferentialEquations, MCMCChains, ForwardDiff, Distributions using NeuralPDE, Flux, OptimizationOptimisers, AdvancedHMC, Lux using StatProfilerHTML, Profile, Statistics, Random, Functors, ComponentArrays -using BenchmarkTools, Plots, StatsPlots, Test -plotly() -Profile.init() +using BenchmarkTools, Test # for sampled params->lux ComponentArray function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) @@ -155,12 +153,11 @@ linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t)) sol = solve(prob, Tsit5(); saveat = 0.05) u = sol.u[1:100] time = sol.t[1:100] -plot(sol1.t, sol1.u) # dataset and BPINN create x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u))) dataset = [x̂, time] -plot!(time, x̂) + chainflux1 = Flux.Chain(Flux.Dense(1, 5, tanh), Flux.Dense(5, 5, tanh), Flux.Dense(5, 1)) chainlux1 = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 1)) @@ -212,7 +209,6 @@ init1, re1 = destructure(chainflux1) t = sol.t p = prob.p physsol1 = [linear_analytic(prob.u0, p, t[i]) for i in eachindex(t)] -plot!(t, physsol1) # Mean of last 500 sampled parameter's curves(flux chains)[Ensemble predictions] out = re1.([fhsamplesflux1[i][1:22] for i in 500:1000]) @@ -277,7 +273,6 @@ prob = ODEProblem(lotka_volterra, u0, tspan, p) solution = solve(prob, Tsit5(); saveat = 0.05) # Plot simulation. -plot(solution) time = solution.t u = hcat(solution.u...) # BPINN AND TRAINING DATASET CREATION, NN create, Reconstruct @@ -378,7 +373,6 @@ function getensemble(yu, num_models) end fluxmean = getensemble(yu, length(out)) meanscurve1_1 = prob.u0 .+ (t' .- prob.tspan[1]) .* fluxmean -plot!(t, meanscurve1_1') mean(abs.(u .- meanscurve1_1)) @test mean(abs2.(x̂ .- meanscurve1_1)) < 2e-2 @@ -390,7 +384,6 @@ out = re1.([fhsamplesflux2[i][1:68] for i in 500:1000]) yu = collect(out[i](t') for i in eachindex(out)) fluxmean = getensemble(yu, length(out)) meanscurve1_2 = prob.u0 .+ (t' .- prob.tspan[1]) .* fluxmean -plot!(t, meanscurve1_2') mean(abs.(u .- meanscurve1_2)) @test mean(abs2.(x̂ .- meanscurve1)) < 2e-2