NNODE for DAE Problems #695

Closed
wants to merge 7 commits into from

Changes from all commits

2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
@@ -51,7 +51,7 @@ include("discretize.jl")
include("neural_adapter.jl")
include("advancedHMC_MCMC.jl")

export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
export NNODE, NNDAE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
KolmogorovParamDomain, NNParamKolmogorov,
PhysicsInformedNN, discretize,
27 changes: 27 additions & 0 deletions src/my_test.jl
@@ -0,0 +1,27 @@

using NeuralPDE, OrdinaryDiffEq, OptimizationPolyalgorithms, Lux, OptimizationOptimJL, Test, Statistics, Plots, Optimisers

function fu(u, p, t)
[p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
prob_oop = ODEProblem{false}(fu, u0, (0.0, 3.0), p)
true_sol = solve(prob_oop, Tsit5(), saveat = 0.01)
func = Lux.σ
N = 12
chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, length(u0)))

opt = Optimisers.Adam(0.01)
dx = 0.05
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
sol = solve(prob_oop, alg, verbose=true, maxiters = 3, saveat = 0.01)

@test abs(mean(sol) - mean(true_sol)) < 0.2

# using Plots

# plot(sol)
# plot!(true_sol)
# ylims!(0,8)
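
The script above only exercises the existing `NNODE` path. As a point of comparison, here is a minimal, hedged sketch of how the new `NNDAE` solver might be driven on a small index-1 DAE; the problem (`dae_f`, `u0`, `du0`) and the hyperparameters are illustrative and are not part of this PR.

```julia
# Hypothetical example (not part of this PR's tests): solve the index-1 DAE
#   u1' = -u1,   0 = u1 + u2
# written in out-of-place residual form f(du, u, p, t) = 0, using the NNDAE solver.
using NeuralPDE, OrdinaryDiffEq, Lux, Optimisers

dae_f(du, u, p, t) = [du[1] + u[1], u[1] + u[2]]

u0 = [1.0, -1.0]    # consistent initial state (u2 = -u1)
du0 = [-1.0, 1.0]   # consistent initial derivative
prob = DAEProblem{false}(dae_f, du0, u0, (0.0, 1.0))

chain = Lux.Chain(Lux.Dense(1, 16, Lux.σ), Lux.Dense(16, 2))
alg = NNDAE(chain, Optimisers.Adam(0.01); autodiff = false,
            strategy = NeuralPDE.GridTraining(0.05))

sol = solve(prob, alg; verbose = true, maxiters = 2000, saveat = 0.01)
```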
292 changes: 279 additions & 13 deletions src/ode_solve.jl
@@ -1,4 +1,5 @@
abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
abstract type NeuralPDEAlgorithmDAE <: DiffEqBase.AbstractDAEAlgorithm end

"""
```julia
@@ -29,7 +30,8 @@ of the physics-informed neural network which is used as a solver for a standard

## Example

```julia
```julia
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
ts=[t for t in 1:100]
(u_, t_) = (analytical_func(ts), ts)
function additional_loss(phi, θ)
@@ -96,6 +98,25 @@ function NNODE(chain, opt, init_params = nothing;
NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
end

struct NNDAE{C, O, P, B, K, AL <: Union{Nothing, Function},
S <: Union{Nothing, AbstractTrainingStrategy}
} <:
NeuralPDEAlgorithmDAE
chain::C
opt::O
init_params::P
autodiff::Bool
batch::B
strategy::S
additional_loss::AL
kwargs::K
end
function NNDAE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = nothing, additional_loss = nothing, kwargs...)
NNDAE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
end
"""
```julia
ODEPhi(chain::Lux.AbstractExplicitLayer, t, u0, st)
@@ -120,23 +141,21 @@ mutable struct ODEPhi{C, T, U, S}
end
end

function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
    θ, st = Lux.setup(Random.default_rng(), chain)
    if init_params === nothing
        init_params = ComponentArrays.ComponentArray(θ)
    else
        init_params = ComponentArrays.ComponentArray(init_params)
    end
    ODEPhi(chain, t, u0, st), init_params
end

function generate_phi_θ(chain::Flux.Chain, t, u0, init_params)
    θ, re = Flux.destructure(chain)
    if init_params === nothing
        init_params = θ
    end
    ODEPhi(re, t, u0), init_params
end

@@ -252,6 +271,39 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
sum(abs2, dxdtguess .- fs) / length(t)
end

"""
L2 inner loss for DAEProblems
"""

function inner_loss_DAE end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p) where {C, T, U <: Number}
sum(abs2,f(ode_dfdx(phi, t, θ, autodiff), phi(t, θ), p, t))
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p) where {C, T, U <: Number}
out = phi(t, θ)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
fs = reduce(hcat, [f(dxdtguess[:, i], out, p, arrt[i]) for i in 1:size(out, 2)])
sum(abs2, fs) / length(t)
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p) where {C, T, U}
sum(abs2,f(ode_dfdx(phi, t, θ, autodiff), phi(t, θ), p, t))
end

function inner_loss_DAE(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p) where {C, T, U}
out = Array(phi(t, θ))
arrt = Array(t)
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
fs = reduce(hcat, [f(dxdtguess[:, i], out, p, arrt[i]) for i in 1:size(out, 2)])
sum(abs2, fs) / length(t)
end
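
Net effect of these four methods: for a set of collocation points t₁, …, t_N the trained quantity is the mean squared residual of the implicit system, loss(θ) ≈ (1/N) Σᵢ ‖f(φ′(tᵢ; θ), φ(tᵢ; θ), p, tᵢ)‖², where φ is the `ODEPhi` trial solution and φ′ its time derivative from `ode_dfdx` (the quadrature strategy below integrates the pointwise residual over `tspan` instead of averaging it).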

"""
Representation of the loss function, parametric on the training strategy `strategy`
"""
@@ -334,6 +386,89 @@ function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, ts
error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end

"""
Representation of the loss function, parametric on the training strategy `strategy` for DAE problems
"""
function generate_loss_DAE(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
batch)
integrand(t::Number, θ) = abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p))
integrand(ts, θ) = [abs2(inner_loss_DAE(phi, f, autodiff, t, θ, p)) for t in ts]
@assert batch == 0 # not implemented

function loss(θ, _)
intprob = IntegralProblem(integrand, tspan[1], tspan[2], θ)
sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol)
sol.u
end

return loss
end

function generate_loss_DAE(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
ts = tspan[1]:(strategy.dx):tspan[2]

# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
if batch
sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))
else
sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])
end
end
return loss
end

function generate_loss_DAE(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
batch)
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
ts = adapt(parameterless_type(θ),
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
if batch
sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))
else
sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])
end
end
return loss
end

function generate_loss_DAE(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
batch)
minT = tspan[1]
maxT = tspan[2]

weights = strategy.weights ./ sum(strategy.weights)

N = length(weights)
samples = strategy.samples

difference = (maxT - minT) / N

data = Float64[]
for (index, item) in enumerate(weights)
temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+
((index - 1) * difference)
data = append!(data, temp_data)
end

ts = data

function loss(θ, _)
if batch
sum(abs2, inner_loss_DAE(phi, f, autodiff, ts, θ, p))
else
sum(abs2, [inner_loss_DAE(phi, f, autodiff, t, θ, p) for t in ts])
end
end
return loss
end

function generate_loss_DAE(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan, p, batch)
    error("QuasiRandomTraining is not supported by NNDAE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end


struct NNODEInterpolation{T <: ODEPhi, T2}
phi::T
θ::T2
@@ -440,7 +575,6 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
else
Optimization.AutoZygote()
end

# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

@@ -483,3 +617,135 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
dense_errors = false)
sol
end #solve

function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,
alg::NNDAE,
args...;
dt = nothing,
timeseries_errors = true,
save_everystep = true,
adaptive = false,
abstol = 1.0f-6,
reltol = 1.0f-3,
verbose = false,
saveat = nothing,
maxiters = nothing)

u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p
t0 = tspan[1]

#hidden layer
chain = alg.chain
opt = alg.opt
autodiff = alg.autodiff

#train points generation
init_params = alg.init_params

if chain isa Lux.AbstractExplicitLayer || chain isa Flux.Chain
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
else
error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported")
end

# if isinplace(prob)
# throw(error("The NNODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
# end

try
phi(t0, init_params)
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
else
throw(err)
end
end

strategy = if alg.strategy === nothing
if dt !== nothing
GridTraining(dt)
else
QuadratureTraining(; quadrature_alg = QuadGKJL(),
reltol = convert(eltype(u0), reltol),
abstol = convert(eltype(u0), abstol), maxiters = maxiters,
batch = 0)
end
else
alg.strategy
end

batch = if alg.batch === nothing
if strategy isa QuadratureTraining
strategy.batch
else
true
end
else
alg.batch
end

inner_f = generate_loss_DAE(strategy, phi, f, autodiff, tspan, p, batch)
additional_loss = alg.additional_loss

# Creates OptimizationFunction Object from total_loss
function total_loss(θ, _)
L2_loss = inner_f(θ, phi)
if !(additional_loss isa Nothing)
return additional_loss(phi, θ) + L2_loss
end
L2_loss
end
# Choice of Optimization Algo for Training Strategies
opt_algo = if strategy isa QuadratureTraining
Optimization.AutoForwardDiff()
else
Optimization.AutoZygote()
end

# Creates OptimizationFunction Object from total_loss
optf = OptimizationFunction(total_loss, opt_algo)

iteration = 0
callback = function (p, l)
iteration += 1
verbose && println("Current loss is: $l, Iteration: $iteration")
l < abstol
end

optprob = OptimizationProblem(optf, init_params)
println("attempting to solve")
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

#solutions at timepoints
if saveat isa Number
ts = tspan[1]:saveat:tspan[2]
elseif saveat isa AbstractArray
ts = saveat
elseif dt !== nothing
ts = tspan[1]:dt:tspan[2]
elseif save_everystep
ts = range(tspan[1], tspan[2], length = 100)
else
ts = [tspan[1], tspan[2]]
end

if u0 isa Number
u = [first(phi(t, res.u)) for t in ts]
else
u = [phi(t, res.u) for t in ts]
end

sol = DiffEqBase.build_solution(prob, alg, ts, u;
k = res, dense = true,
interp = NNODEInterpolation(phi, res.u),
calculate_error = false,
retcode = ReturnCode.Success)
DiffEqBase.has_analytic(prob.f) &&
DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
dense_errors = false)
sol
end #solve