Merge pull request #723 from sdesai1287/given_points_training

Incorporating given points into training

ChrisRackauckas authored Aug 22, 2023
2 parents 4a469b0 + 6478ea9 commit 14721af
Showing 4 changed files with 120 additions and 17 deletions.
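In short: `solve` for `NNODE` gains a `tstops` keyword. The physics residual is additionally evaluated at the supplied times, and for grid, weighted-interval, and stochastic training the two loss terms are combined with point-count weights. A minimal usage sketch, condensed from the new test file below (hyperparameters are illustrative and the network is smaller than the one the tests use):

```julia
using NeuralPDE, OrdinaryDiffEq, Lux, Optimisers, OptimizationOptimisers

# Lotka-Volterra, as in the new test file
fu(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
prob = ODEProblem{false}(fu, [1.0, 1.0], (0.0, 3.0), [1.5, 1.0, 3.0, 1.0])

chain = Lux.Chain(Lux.Dense(1, 12, Lux.σ), Lux.Dense(12, 12, Lux.σ), Lux.Dense(12, 2))
alg = NeuralPDE.NNODE(chain, Optimisers.Adam(0.01), autodiff = false,
                      strategy = NeuralPDE.GridTraining(1.0))

# Extra residual points supplied through the new `tstops` keyword
addedPoints = vcat([rand() for _ in 1:280], [rand() + 1 for _ in 1:80], [rand() + 2 for _ in 1:40])
sol = solve(prob, alg, verbose = true, maxiters = 30000, saveat = 0.01, tstops = addedPoints)
```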
65 changes: 50 additions & 15 deletions src/ode_solve.jl
@@ -30,6 +30,7 @@ of the physics-informed neural network which is used as a solver for a standard
## Example
```julia
u0 = [1.0, 1.0]
ts = [t for t in 1:100]
(u_, t_) = (analytical_func(ts), ts)
function additional_loss(phi, θ)
@@ -120,23 +121,21 @@ mutable struct ODEPhi{C, T, U, S}
end
end

-function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing)
-    θ, st = Lux.setup(Random.default_rng(), chain)
-    ODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(θ)
-end
-
 function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
     θ, st = Lux.setup(Random.default_rng(), chain)
-    ODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(init_params)
+    if init_params === nothing
+        init_params = ComponentArrays.ComponentArray(θ)
+    else
+        init_params = ComponentArrays.ComponentArray(init_params)
+    end
+    ODEPhi(chain, t, u0, st), init_params
 end

-function generate_phi_θ(chain::Flux.Chain, t, u0, init_params::Nothing)
-    θ, re = Flux.destructure(chain)
-    ODEPhi(re, t, u0), θ
-end
-
 function generate_phi_θ(chain::Flux.Chain, t, u0, init_params)
     θ, re = Flux.destructure(chain)
+    if init_params === nothing
+        init_params = θ
+    end
     ODEPhi(re, t, u0), init_params
 end
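For reference, a minimal sketch of how the unified methods behave for both call patterns. It assumes the internal convention `generate_phi_θ(chain, t0, u0, init_params)` with `t0 = tspan[1]`; this is not part of the public API:

```julia
using NeuralPDE, Lux, Random, ComponentArrays

chain = Lux.Chain(Lux.Dense(1, 8, Lux.σ), Lux.Dense(8, 2))
# init_params === nothing: parameters come from Lux.setup, wrapped in a ComponentArray
phi, θ = NeuralPDE.generate_phi_θ(chain, 0.0, [1.0, 1.0], nothing)
# given init_params: wrapped in a ComponentArray and passed through unchanged
phi2, θ2 = NeuralPDE.generate_phi_θ(chain, 0.0, [1.0, 1.0], θ)
```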

@@ -258,6 +257,7 @@ Representation of the loss function, parametric on the training strategy `strategy`
function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
                       batch)
    integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p))

    integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p)) for t in ts]
    @assert batch == 0 # not implemented

@@ -290,6 +290,7 @@ function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tsp
    function loss(θ, _)
        ts = adapt(parameterless_type(θ),
                   [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])

        if batch
            sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
        else
@@ -330,10 +331,24 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
    return loss
end

function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch)
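    # Residual loss evaluated only at the user-supplied tstops points; same
    # batch convention as the per-strategy losses above.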

    # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
    function loss(θ, _)
        if batch
            sum(abs2, inner_loss(phi, f, autodiff, tstops, θ, p))
        else
            sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in tstops])
        end
    end
    return loss
end

function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
    error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end


struct NNODEInterpolation{T <: ODEPhi, T2}
    phi::T
    θ::T2
@@ -364,7 +379,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
                            reltol = 1.0f-3,
                            verbose = false,
                            saveat = nothing,
                            maxiters = nothing,
                            tstops = nothing)
    u0 = prob.u0
    tspan = prob.tspan
    f = prob.f
@@ -429,9 +445,29 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
 function total_loss(θ, _)
     L2_loss = inner_f(θ, phi)
     if !(additional_loss isa Nothing)
-        return additional_loss(phi, θ) + L2_loss
+        L2_loss = L2_loss + additional_loss(phi, θ)
     end
-    L2_loss
+    if !(tstops isa Nothing)
+        num_tstops_points = length(tstops)
+        tstops_loss_func = evaluate_tstops_loss(phi, f, autodiff, tstops, p, batch)
+        total_tstops_loss = tstops_loss_func(θ, phi) * num_tstops_points
+        if strategy isa GridTraining
+            num_original_points = length(tspan[1]:(strategy.dx):tspan[2])
+        elseif strategy isa WeightedIntervalTraining
+            num_original_points = strategy.samples
+        elseif strategy isa StochasticTraining
+            num_original_points = strategy.points
+        else
+            L2_loss = L2_loss + tstops_loss_func(θ, phi)
+            return L2_loss
+        end
+
+        total_original_loss = L2_loss * num_original_points
+        total_points = num_original_points + num_tstops_points
+        L2_loss = (total_original_loss + total_tstops_loss) / total_points
+
+    end
+    return L2_loss
 end

    # Choice of optimization algorithm for the training strategy
@@ -440,7 +476,6 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
    else
        Optimization.AutoZygote()
    end

    # Creates an OptimizationFunction object from total_loss
    optf = OptimizationFunction(total_loss, opt_algo)

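The per-strategy branch in `total_loss` above exists only to recover each strategy's point count; the final line is a point-count-weighted average of the strategy loss and the tstops loss. An isolated sketch of that arithmetic with illustrative numbers (variable names mirror the code; the loss values are made up):

```julia
# Weighted combination from total_loss, isolated for clarity.
num_original_points = length(0.0:1.0:3.0)   # GridTraining: tspan[1]:dx:tspan[2] has 4 points here
num_tstops_points = 400                     # length(tstops)
L2_loss = 0.02                              # illustrative strategy loss
tstops_loss = 0.05                          # illustrative tstops loss

total_original_loss = L2_loss * num_original_points
total_tstops_loss = tstops_loss * num_tstops_points
weighted = (total_original_loss + total_tstops_loss) / (num_original_points + num_tstops_points)
```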
2 changes: 1 addition & 1 deletion src/training_strategies.jl
@@ -331,4 +331,4 @@ function get_loss_function(loss_function, train_set, eltypeθ,
                           strategy::WeightedIntervalTraining;
                           τ = nothing)
    loss = (θ) -> mean(abs2, loss_function(train_set, θ))
end
67 changes: 67 additions & 0 deletions test/NNODE_tstops_test.jl
@@ -0,0 +1,67 @@
using OrdinaryDiffEq, Lux, OptimizationOptimisers, Test, Statistics, Optimisers, NeuralPDE

function fu(u, p, t)
    [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
tspan = (0.0, 3.0)
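# 400 extra residual points, concentrated early in tspan: 280 in [0, 1), 80 in [1, 2), 40 in [2, 3)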
points1 = [rand() for i in 1:280]
points2 = [rand() + 1 for i in 1:80]
points3 = [rand() + 2 for i in 1:40]
addedPoints = vcat(points1, points2, points3)

saveat = 0.01
maxiters = 30000

prob_oop = ODEProblem{false}(fu, u0, tspan, p)
true_sol = solve(prob_oop, Tsit5(), saveat = saveat)
func = Lux.σ
N = 12
chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
                  Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))

opt = Optimisers.Adam(0.01)
threshold = 0.2

# Deliberately bad choices of weights, samples, and dx, so that training fails without the added points
weights = [0.3, 0.3, 0.4]
samples = 3
dx = 1.0

# GridTraining without added points (difference between solutions should be high)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)

@test abs(mean(sol) - mean(true_sol)) > threshold

# GridTraining with added points (difference between solutions should be low)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)

@test abs(mean(sol) - mean(true_sol)) < threshold

# WeightedIntervalTraining without added points (difference between solutions should be high)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)

@test abs(mean(sol) - mean(true_sol)) > threshold

# WeightedIntervalTraining with added points (difference between solutions should be low)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)

@test abs(mean(sol) - mean(true_sol)) < threshold

# StochasticTraining without added points (difference between solutions should be high)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat)

@test abs(mean(sol) - mean(true_sol)) > threshold

# StochasticTraining with added points (difference between solutions should be low)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
sol = solve(prob_oop, alg, verbose=true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)

@test abs(mean(sol) - mean(true_sol)) < threshold
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -25,6 +25,7 @@ end
end
if GROUP == "All" || GROUP == "NNODE"
    @time @safetestset "NNODE" begin include("NNODE_tests.jl") end
    @time @safetestset "NNODE_tstops" begin include("NNODE_tstops_test.jl") end
end

if GROUP == "All" || GROUP == "NNPDE2"
@@ -60,4 +61,4 @@
@safetestset "NNPDE_gpu" begin include("NNPDE_tests_gpu.jl") end
@safetestset "NNPDE_gpu_Lux" begin include("NNPDE_tests_gpu_Lux.jl") end
end
end
end
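The new test file is wired into the NNODE test group. A sketch of how to run just that group, assuming the usual SciML convention that `runtests.jl` reads `GROUP` from the environment (an assumption, not shown in this diff):

```julia
# From the NeuralPDE.jl repository root (GROUP selection via ENV is assumed):
#   GROUP=NNODE julia --project -e 'using Pkg; Pkg.test()'
```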
