Commit

changes

hippyhippohops committed Jul 11, 2024
1 parent 6194b64 commit 7d26f4c

Showing 3 changed files with 198 additions and 11 deletions.
99 changes: 98 additions & 1 deletion src/dae_solve.jl
@@ -47,6 +47,25 @@ function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff =
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end
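
A hedged usage sketch (editor's illustration, not part of this commit; the chain and optimizer mirror the tests below) of how this constructor is typically invoked:

    using NeuralPDE, Lux, OptimizationOptimisers
    chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
    alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)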


function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ,
autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: Number}
if autodiff
ForwardDiff.derivative(t -> phi(t, θ), t)
else
(phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
end
end

function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ,
        autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: AbstractVector}
    if autodiff
        # t is a scalar, so ForwardDiff.derivative applies even though phi returns a
        # vector; ForwardDiff.jacobian would require an array-valued input.
        ForwardDiff.derivative(t -> phi(t, θ), t)
    else
        (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
    end
end
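
A quick sanity check of the finite-difference fallback above, as an editor's sketch rather than part of the commit (g stands in for phi, and the tolerance is an assumption):

    using ForwardDiff
    g(t) = sin(t)                       # stand-in for t -> phi(t, θ)
    t0 = 1.0
    h = sqrt(eps(typeof(t0)))           # same step size as the fallback branch
    fd = (g(t0 + h) - g(t0)) / h        # forward difference
    ad = ForwardDiff.derivative(g, t0)  # autodiff reference
    @assert isapprox(fd, ad; atol = 1e-6)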

function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool,
differential_vars::AbstractVector)
if autodiff
Expand All @@ -69,6 +88,19 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
sum(abs2, loss) / length(t)
end


function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
        p, differential_vars::AbstractVector) where {C, T, U}
    dphi = dfdx(phi, t, θ, autodiff, differential_vars)
    sum(abs2, f(dphi, phi(t, θ), p, t))
end
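
For context, an editor's sketch (not from the commit) of the residual convention this loss assumes: f(du, u, p, t) returns the DAE residual, which a well-trained phi drives to zero. Using the system from the tests below, inputs chosen so both components vanish give a zero residual and hence a zero loss contribution:

    # Residual form of du_i = u_i - t for i = 1, 2
    example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
    residual = example([1.0, 0.0], [1.0, 0.0], nothing, 0.0)
    @assert all(iszero, residual)       # so sum(abs2, residual) is 0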

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
ts = tspan[1]:(strategy.dx):tspan[2]
Expand All @@ -79,6 +111,65 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

function generate_loss(
strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
    autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
minT = tspan[1]
maxT = tspan[2]

weights = strategy.weights ./ sum(strategy.weights)

N = length(weights)
points = strategy.points

difference = (maxT - minT) / N

    data = Float64[]
    for (index, item) in enumerate(weights)
        temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
                    ((index - 1) * difference)
        append!(data, temp_data)
    end

ts = data

function loss(θ, _)
sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
end
return loss
end
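
The strategy above splits tspan into length(weights) equal subintervals and draws roughly points * weights[i] uniform samples from the i-th one, so larger weights concentrate training points. A standalone sketch of that scheme (editor's illustration; the specific weights and counts are assumptions matching the tests below):

    weights = [0.7, 0.2, 0.1]           # assumed already normalized
    points, minT, maxT = 200, 0.0, 1.0
    difference = (maxT - minT) / length(weights)
    data = Float64[]
    for (index, item) in enumerate(weights)
        append!(data, rand(trunc(Int, points * item)) .* difference .+ minT .+
                      (index - 1) * difference)
    end
    # ~70% of the samples land in the first third of the interval
    @assert count(t -> t < difference, data) == trunc(Int, points * weights[1])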


function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, differential_vars))

function integrand(ts, θ)
[sum(abs2, inner_loss(phi, f, autodiff, t, θ, p, differential_vars)) for t in ts]
end

function loss(θ, _)
intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol,
reltol = strategy.reltol, maxiters = strategy.maxiters)
sol.u
end
return loss
end
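
The quadrature path leans on Integrals.jl's batched interface: the second integrand method receives a whole vector of nodes per call. A minimal, self-contained sketch of that pattern (editor's assumption that Integrals.jl is loaded; the integrand here is arbitrary):

    using Integrals
    batched(ts, p) = [abs2(sin(t)) for t in ts]        # evaluate many nodes at once
    intf = BatchIntegralFunction(batched, max_batch = 100)
    intprob = IntegralProblem(intf, (0.0, 1.0), nothing)
    sol = solve(intprob, QuadGKJL(); abstol = 1e-8, reltol = 1e-8)
    @assert isapprox(sol.u, 0.5 - sin(2) / 4; atol = 1e-6)  # ∫₀¹ sin(t)² dt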

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
function loss(θ, _)
ts = adapt(parameterless_type(θ),
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
end
return loss
end
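
By contrast, StochasticTraining simply redraws strategy.points uniform samples from tspan on every loss evaluation. A hedged usage sketch (chain and opt as in the test file):

    alg = NNDAE(chain, opt; strategy = StochasticTraining(1000), autodiff = false)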

function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
alg::NNDAE,
args...;
@@ -136,8 +227,13 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
if dt !== nothing
GridTraining(dt)
else
error("dt is not defined")
QuadratureTraining(; quadrature_alg = QuadGKJL(),
reltol = convert(eltype(u0), reltol),
abstol = convert(eltype(u0), abstol), maxiters = maxiters,
batch = 0)
end
else
alg.strategy
end

inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars)
@@ -189,3 +285,4 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
dense_errors = false)
sol
end

5 changes: 5 additions & 0 deletions src/ode_solve.jl
@@ -225,6 +225,7 @@ end
Representation of the loss function, parametric on the training strategy `strategy`.
"""

function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
batch, param_estim::Bool)
integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim))
@@ -304,6 +305,8 @@ function generate_loss(
return loss
end



function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch, param_estim::Bool)
function loss(θ, _)
if batch
@@ -319,6 +322,7 @@ function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, ts
error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end


struct NNODEInterpolation{T <: ODEPhi, T2}
phi::T
θ::T2
@@ -490,3 +494,4 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
dense_errors = false)
sol
end #solve

105 changes: 95 additions & 10 deletions test/NNDAE_tests.jl
@@ -16,7 +16,7 @@ Random.seed!(100)
M = [1.0 0
0 0]
f = ODEFunction(example1, mass_matrix = M)
    tspan = (0.0, 1.0)

prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
@@ -25,13 +25,13 @@ Random.seed!(100)
differential_vars = [true, false]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, cos), Lux.Dense(15, 15, sin), Lux.Dense(15, 2))
    opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
    alg = NNDAE(chain, opt; autodiff = false)

sol = solve(prob,
        alg, verbose = false, dt = 1 / 100.0,
        maxiters = 3000, abstol = 1e-10)
    @test reduce(hcat, ground_sol(0:(1 / 100):1).u) ≈ reduce(hcat, sol.u) rtol=1e-1
end

@testset "Example 2" begin
@@ -44,7 +44,7 @@ end
0 1]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
    tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
@@ -57,8 +57,93 @@ end
alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)

sol = solve(prob,
        alg, verbose = false, dt = 1 / 100.0,
        maxiters = 3000, abstol = 1e-10)

    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u) ≈ reduce(hcat, sol.u) rtol=1e-2
end

@testset "WeightedIntervalTraining" begin
function example2(du, u, p, t)
du[1] = u[1] - t
du[2] = u[2] - t
nothing
end
M = [0.0 0.0
0.0 1.0]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
weights = [0.7, 0.2, 0.1]
points = 200
    alg = NNDAE(chain, opt,
        strategy = WeightedIntervalTraining(weights, points); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0,
maxiters = 3000, abstol = 1e-10)

    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u) ≈ reduce(hcat, sol.u) rtol=1e-2
end

@testset "StochasticTraining" begin
function example2(du, u, p, t)
du[1] = u[1] - t
du[2] = u[2] - t
nothing
end
M = [0.0 0.0
0.0 1.0]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimisers.Adam(0.1)
    alg = NNDAE(chain, opt,
        strategy = StochasticTraining(1000); autodiff = false)
sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0,
maxiters = 3000, abstol = 1e-10)
    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u) ≈ reduce(hcat, sol.u) rtol=1e-2
end

@testset "QuadratureTraining" begin
function example2(du, u, p, t)
du[1] = u[1] - t
du[2] = u[2] - t
nothing
end
M = [0.0 0.0
0.0 1.0]
u₀ = [0.0, 0.0]
du₀ = [0.0, 0.0]
tspan = (0.0, pi / 2.0)
f = ODEFunction(example2, mass_matrix = M)
prob_mm = ODEProblem(f, u₀, tspan)
ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)

example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
differential_vars = [false, true]
prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
    sol = solve(prob, alg, verbose = true, maxiters = 6000, abstol = 1e-10, dt = 1 / 100.0)
    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u) ≈ reduce(hcat, sol.u) rtol=1e-2
end
