Skip to content

Commit

Permalink
Merge pull request #741 from AstitvaAggarwal/develop
Browse files Browse the repository at this point in the history
changes from final reviews, improved Docs
  • Loading branch information
Vaibhavdixit02 authored Oct 5, 2023
2 parents 100bb0f + 70615bd commit 00cc7f5
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 190 deletions.
2 changes: 1 addition & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pages = ["index.md",
"ODE PINN Tutorials" => Any["Introduction to NeuralPDE for ODEs" => "tutorials/ode.md",
"Baysian PINNs - Lotka-Volterra" => "examples/Lotka_Volterra_BPINNs.md"
"Bayesian PINNs for Coupled ODEs - Lotka-Volterra" => "examples/Lotka_Volterra_BPINNs.md"
#"examples/nnrode_example.md", # currently incorrect
],
"PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
Expand Down
83 changes: 34 additions & 49 deletions docs/src/examples/Lotka_Volterra_BPINNs.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ The Lotka–Volterra equations, also known as the predator–prey equations, are
These differential equations are frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey.
The populations change through time according to the pair of equations

$$
```math
\begin{aligned}
\frac{\mathrm{d}x}{\mathrm{d}t} &= (\alpha - \beta y(t))x(t), \\
\frac{\mathrm{d}y}{\mathrm{d}t} &= (\delta x(t) - \gamma)y(t)
\end{aligned}
$$
```

where $x(t)$ and $y(t)$ denote the populations of prey and predator at time $t$, respectively, and $\alpha, \beta, \gamma, \delta$ are positive parameters.

Expand All @@ -23,7 +23,9 @@ We then solve the equations and estimate the parameters of the model with priors

And also solve the equations for the constructed ODEProblem's provided ideal `p` values using a Lux.jl Neural Network, chain_lux.

```julia
```julia
using NeuralPDE, Flux, Lux, Plots, StatsPlots, OrdinaryDiffEq, Distributions

function lotka_volterra(u, p, t)
# Model parameters.
α, β, γ, δ = p
Expand All @@ -43,11 +45,11 @@ p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 6.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Plot simulation.
```
With the [`saveat` argument](https://docs.sciml.ai/latest/basics/common_solver_opts/) we can specify that the solution is stored only at `saveat` time units(default saveat=1 / 50.0).

```julia
# Plot solution got by Standard DifferentialEquations.jl ODE solver
solution = solve(prob, Tsit5(); saveat = 0.05)
plot(solve(prob, Tsit5()))

Expand All @@ -58,67 +60,49 @@ To make the example more realistic we add random normally distributed noise to t


```julia
# Dataset creation for parameter estimation
# Dataset creation for parameter estimation (30% noise)
time = solution.t
u = hcat(solution.u...)
x = u[1, :] + 0.5 * randn(length(u[1, :]))
y = u[2, :] + 0.5 * randn(length(u[1, :]))
u = hcat(solution.u...)
x = u[1, :] + (0.3 .*u[1, :]).*randn(length(u[1, :]))
y = u[2, :] + (0.3 .*u[2, :]).*randn(length(u[2, :]))
dataset = [x, y, time]

# Neural Networks must have 2 outputs as u -> [dx,dy] in function lotka_volterra()
chainflux = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 6, tanh), Flux.Dense(6, 2)) |> Flux.f64
chainflux = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 6, tanh),
Flux.Dense(6, 2)) |> Flux.f64
chainlux = Lux.Chain(Lux.Dense(1, 7, Lux.tanh), Lux.Dense(7, 7, Lux.tanh),
Lux.Dense(7, 2))

chainlux = Lux.Chain(Lux.Dense(1, 6, Lux.tanh), Lux.Dense(6, 6, Lux.tanh), Lux.Dense(6, 2))
```
A Dataset is required as parameter estimation is being done using provided priors in `param` keyword argument for BNNODE.

```julia
alg1 = NeuralPDE.BNNODE(chainflux,
dataset = dataset,
draw_samples = 1000,
l2std = [
0.05,
0.05,
],
phystd = [
0.05,
0.05,
],
priorsNNw = (0.0,
3.0),
l2std = [0.1, 0.1],
phystd = [0.1, 0.1],
priorsNNw = (0.0, 3.0),
param = [
Normal(1,
2),
Normal(2,
2),
Normal(2,
2),
Normal(0,
2),
],
n_leapfrog = 30, progress = true)
Normal(1, 2),
Normal(2, 2),
Normal(2, 2),
Normal(0, 2),
], progress = true)

sol_flux_pestim = solve(prob, alg1)

# Dataset not needed as we are solving the equation with ideal parameters
alg2 = NeuralPDE.BNNODE(chainlux,
draw_samples = 1000,
l2std = [
0.05,
0.05,
],
phystd = [
0.05,
0.05,
],
priorsNNw = (0.0,
3.0),
n_leapfrog = 30, progress = true)
phystd = [0.05, 0.05],
priorsNNw = (0.0, 10.0),
progress = true)

sol_lux = solve(prob, alg2)

#testing timepoints must match keyword arg `saveat`` timepoints of solve() call
t=collect(Float64,prob.tspan[1]:1/50.0:prob.tspan[2])
t = collect(Float64, prob.tspan[1]:(1 / 50.0):prob.tspan[2])

```

Expand All @@ -133,13 +117,14 @@ plot(t,sol_flux_pestim.ensemblesol[1])
plot!(t,sol_flux_pestim.ensemblesol[2])

# estimated ODE parameters by .estimated_ode_params, weights and biases by .estimated_nn_params
println(sol_flux_pestim.estimated_ode_params)
sol_flux_pestim.estimated_nn_params
sol_flux_pestim.estimated_ode_params


# plotting solution for x,y for chain_lux
plot(t,sol_lux_pestim.ensemblesol[1])
plot!(t,sol_lux_pestim.ensemblesol[2])
plot(t,sol_lux.ensemblesol[1])
plot!(t,sol_lux.ensemblesol[2])

# estimated weights and biases by .estimated_nn_params for chain_lux
sol_lux_pestim.estimated_nn_params
```
# estimated weights and biases by .estimated_nn_params for chain_lux
sol_lux.estimated_nn_params

```
4 changes: 2 additions & 2 deletions docs/src/examples/ks.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ Let's consider the Kuramoto–Sivashinsky equation, which contains a 4th-order d
∂_t u(x, t) + u(x, t) ∂_x u(x, t) + \alpha ∂^2_x u(x, t) + \beta ∂^3_x u(x, t) + \gamma ∂^4_x u(x, t) = 0 \, ,
```

where `\alpha = \gamma = 1` and `\beta = 4`. The exact solution is:
where $\alpha = \gamma = 1$ and $\beta = 4$. The exact solution is:

```math
u_e(x, t) = 11 + 15 \tanh \theta - 15 \tanh^2 \theta - 15 \tanh^3 \theta \, ,
```

where `\theta = t - x/2` and with initial and boundary conditions:
where $\theta = t - x/2$ and with initial and boundary conditions:

```math
\begin{align*}
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/param_estim.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Consider a Lorenz System,
\end{align*}
```

with Physics-Informed Neural Networks. Now we would consider the case where we want to optimize the parameters `\sigma`, `\beta`, and `\rho`.
with Physics-Informed Neural Networks. Now we would consider the case where we want to optimize the parameters $\sigma$, $\beta$, and $\rho$.

We start by defining the problem,

Expand Down
105 changes: 40 additions & 65 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
```julia
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing],
init_params = nothing, physdt = 1 / 20.0, nchains = 1,
autodiff = false, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, max_depth = 10, Δ_max = 1000,
n_leapfrog = 20, δ = 0.65, λ = 0.3, progress = false,
verbose = false)
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
```
Algorithm for solving ordinary differential equations using a Bayesian neural network. This is a specialization
Expand Down Expand Up @@ -51,17 +48,16 @@ chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6,
alg = NeuralPDE.BNNODE(chainlux, draw_samples = 2000,
l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0),
n_leapfrog = 30, progress = true)
priorsNNw = (0.0, 3.0), progress = true)
sol_lux = solve(prob, alg)
# with parameter estimation
alg = NeuralPDE.BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
n_leapfrog = 30, progress = true)
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
sol_lux_pestim = solve(prob, alg)
```
Expand All @@ -84,9 +80,11 @@ Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, El
"Bayesian Physics Informed Neural Networks for real-world nonlinear dynamical systems"
"""
struct BNNODE{C, K, ST <: Union{Nothing, AbstractTrainingStrategy}, IT, A, M,
struct BNNODE{C, K, IT <: NamedTuple,
A <: NamedTuple, H <: NamedTuple,
ST <: Union{Nothing, AbstractTrainingStrategy},
I <: Union{Nothing, Vector{<:AbstractFloat}},
P <: Union{Vector{Nothing}, Vector{<:Distribution}},
P <: Union{Nothing, Vector{<:Distribution}},
D <:
Union{Vector{Nothing}, Vector{<:Vector{<:AbstractFloat}}}} <:
NeuralPDEAlgorithm
Expand All @@ -99,48 +97,31 @@ struct BNNODE{C, K, ST <: Union{Nothing, AbstractTrainingStrategy}, IT, A, M,
l2std::Vector{Float64}
phystd::Vector{Float64}
dataset::D
init_params::I
physdt::Float64
MCMCkwargs::H
nchains::Int64
init_params::I
Adaptorkwargs::A
Integratorkwargs::IT
autodiff::Bool
Integrator::IT
Adaptor::A
targetacceptancerate::Float64
Metric::M
jitter_rate::Float64
tempering_rate::Float64
max_depth::Int64
Δ_max::Int64
n_leapfrog::Int64
δ::Float64
λ::Float64
progress::Bool
verbose::Bool

function BNNODE(chain, Kernel = HMC; strategy = nothing,
draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing],
init_params = nothing,
physdt = 1 / 20.0, nchains = 1,
autodiff = false, Integrator = Leapfrog,
Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8,
Metric = DiagEuclideanMetric, jitter_rate = 3.0,
tempering_rate = 3.0, max_depth = 10, Δ_max = 1000,
n_leapfrog = 20, δ = 0.65, λ = 0.3, progress = false,
verbose = false)
new{typeof(chain), typeof(Kernel), typeof(strategy), typeof(Integrator),
typeof(Adaptor),
typeof(Metric), typeof(init_params), typeof(param),
typeof(dataset)}(chain, Kernel, strategy, draw_samples,
priorsNNw, param, l2std,
phystd, dataset, init_params,
physdt, nchains, autodiff, Integrator,
Adaptor, targetacceptancerate,
Metric, jitter_rate, tempering_rate,
max_depth, Δ_max, n_leapfrog,
δ, λ, progress, verbose)
end
end
function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = nothing, l2std = [0.05], phystd = [0.05],
dataset = [nothing], physdt = 1 / 20.0, MCMCkwargs = (n_leapfrog = 30,), nchains = 1,
init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
autodiff = false, progress = false, verbose = false)
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
nchains, init_params,
Adaptorkwargs, Integratorkwargs,
autodiff, progress, verbose)
end

"""
Expand Down Expand Up @@ -199,12 +180,12 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
maxiters = nothing,
numensemble = floor(Int, alg.draw_samples / 3))
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params, Integrator, Adaptor, Metric,
nchains, max_depth, Δ_max, n_leapfrog, physdt, targetacceptancerate,
jitter_rate, tempering_rate, δ, λ, autodiff, progress, verbose = alg
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, autodiff, progress, verbose = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param == [nothing] ? [] : param
param = param === nothing ? [] : param
strategy = strategy === nothing ? GridTraining : strategy

if draw_samples < 0
Expand All @@ -222,16 +203,10 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
nchains = nchains,
autodiff = autodiff,
Kernel = Kernel,
Integrator = Integrator,
Adaptor = Adaptor,
targetacceptancerate = targetacceptancerate,
Metric = Metric,
jitter_rate = jitter_rate,
tempering_rate = tempering_rate,
max_depth = max_depth,
Δ_max = Δ_max,
n_leapfrog = n_leapfrog, δ = δ,
λ = λ, progress = progress,
Adaptorkwargs = Adaptorkwargs,
Integratorkwargs = Integratorkwargs,
MCMCkwargs = MCMCkwargs,
progress = progress,
verbose = verbose)

fullsolution = BPINNstats(mcmcchain, samples, statistics)
Expand Down
Loading

0 comments on commit 00cc7f5

Please sign in to comment.