diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index e40499e3a..db4602d73 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -30,6 +30,7 @@ of the physics-informed neural network which is used as a solver for a standard
 ## Example

 ```julia
+u0 = [1.0, 1.0]
 ts=[t for t in 1:100]
 (u_, t_) = (analytical_func(ts), ts)
 function additional_loss(phi, θ)
@@ -120,23 +121,21 @@ mutable struct ODEPhi{C, T, U, S}
     end
 end

-function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing)
-    θ, st = Lux.setup(Random.default_rng(), chain)
-    ODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(θ)
-end
-
 function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
     θ, st = Lux.setup(Random.default_rng(), chain)
-    ODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(init_params)
-end
-
-function generate_phi_θ(chain::Flux.Chain, t, u0, init_params::Nothing)
-    θ, re = Flux.destructure(chain)
-    ODEPhi(re, t, u0), θ
+    if init_params === nothing
+        init_params = ComponentArrays.ComponentArray(θ)
+    else
+        init_params = ComponentArrays.ComponentArray(init_params)
+    end
+    ODEPhi(chain, t, u0, st), init_params
 end

 function generate_phi_θ(chain::Flux.Chain, t, u0, init_params)
     θ, re = Flux.destructure(chain)
+    if init_params === nothing
+        init_params = θ
+    end
     ODEPhi(re, t, u0), init_params
 end

@@ -258,6 +257,7 @@ Representation of the loss function, parametric on the training strategy `strategy`
 function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p, batch)
     integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p))
+    integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p)) for t in ts]

     @assert batch == 0 # not implemented

@@ -290,6 +290,7 @@ function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p, batch)
     function loss(θ, _)
         ts = adapt(parameterless_type(θ),
                    [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
+
         if batch
             sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
         else
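Reviewer note: the consolidated `generate_phi_θ` methods above replace the old `init_params::Nothing` specializations with a single runtime branch. A minimal standalone sketch of the Lux path, assuming only `Lux`, `Random`, and `ComponentArrays` (the two-layer `chain` is hypothetical, not from this patch):

```julia
using Lux, Random, ComponentArrays

chain = Lux.Chain(Lux.Dense(1, 8, Lux.σ), Lux.Dense(8, 1))  # hypothetical network
θ, st = Lux.setup(Random.default_rng(), chain)

# When the user supplies no parameters, fall back to the network's own
# initialization; either way the caller receives a ComponentArray.
init_params = nothing
ps = init_params === nothing ? ComponentArrays.ComponentArray(θ) :
     ComponentArrays.ComponentArray(init_params)
```

The Flux path below has the same shape, with `Flux.destructure(chain)` supplying the default parameter vector.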
@@ -330,10 +331,24 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, batch)
     return loss
 end

+function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch)
+
+    # sum(abs2, inner_loss(t, θ) for t in tstops), written as a comprehension because Zygote cannot differentiate generators
+    function loss(θ, _)
+        if batch
+            sum(abs2, inner_loss(phi, f, autodiff, tstops, θ, p))
+        else
+            sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in tstops])
+        end
+    end
+    return loss
+end
+
 function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
     error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
 end
+
 struct NNODEInterpolation{T <: ODEPhi, T2}
     phi::T
     θ::T2
@@ -364,7 +379,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
                             reltol = 1.0f-3,
                             verbose = false,
                             saveat = nothing,
-                            maxiters = nothing)
+                            maxiters = nothing,
+                            tstops = nothing)
     u0 = prob.u0
     tspan = prob.tspan
     f = prob.f
@@ -429,9 +445,29 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
     function total_loss(θ, _)
         L2_loss = inner_f(θ, phi)
         if !(additional_loss isa Nothing)
-            return additional_loss(phi, θ) + L2_loss
+            L2_loss = L2_loss + additional_loss(phi, θ)
         end
-        L2_loss
+        if !(tstops isa Nothing)
+            num_tstops_points = length(tstops)
+            tstops_loss_func = evaluate_tstops_loss(phi, f, autodiff, tstops, p, batch)
+            total_tstops_loss = tstops_loss_func(θ, phi) * num_tstops_points
+            if strategy isa GridTraining
+                num_original_points = length(tspan[1]:(strategy.dx):tspan[2])
+            elseif strategy isa WeightedIntervalTraining
+                num_original_points = strategy.samples
+            elseif strategy isa StochasticTraining
+                num_original_points = strategy.points
+            else
+                L2_loss = L2_loss + tstops_loss_func(θ, phi)
+                return L2_loss
+            end
+
+            total_original_loss = L2_loss * num_original_points
+            total_points = num_original_points + num_tstops_points
+            L2_loss = (total_original_loss + total_tstops_loss) / total_points
+
+        end
+        return L2_loss
     end

     # Choice of Optimization Algo for Training Strategies
@@ -440,7 +476,6 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
     else
         Optimization.AutoZygote()
     end
-
     # Creates OptimizationFunction Object from total_loss
     optf = OptimizationFunction(total_loss, opt_algo)

diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index 6c6dacbb7..3b2f92a4e 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -331,4 +331,4 @@ function get_loss_function(loss_function, train_set, eltypeθ,
                            strategy::WeightedIntervalTraining;
                            τ = nothing)
     loss = (θ) -> mean(abs2, loss_function(train_set, θ))
-end
+end
\ No newline at end of file
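Reviewer note: the `total_loss` rework merges the strategy's own loss with the `tstops` loss as a point-weighted average, so a handful of grid points cannot drown out hundreds of user-supplied points, and vice versa. A small worked sketch of that arithmetic, with made-up numbers and each term treated as a mean over its own point set:

```julia
# Made-up values, for illustration only
L2_loss = 0.04           # loss over the strategy's own training points
tstops_loss = 0.10       # loss over the user-supplied tstops
num_original_points = 4  # e.g. length(0.0:1.0:3.0) for GridTraining with dx = 1.0
num_tstops_points = 400

# Scale each term back to a sum, add, and divide by the total point count
total = (L2_loss * num_original_points + tstops_loss * num_tstops_points) /
        (num_original_points + num_tstops_points)
# ≈ 0.0994: with 400 added points against 4 grid points, the tstops term dominates
```

Strategies without a fixed point count (e.g. QuadratureTraining) skip the averaging and simply add the tstops term, per the `else` branch above.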
diff --git a/test/NNODE_tstops_test.jl b/test/NNODE_tstops_test.jl
new file mode 100644
index 000000000..7d355897e
--- /dev/null
+++ b/test/NNODE_tstops_test.jl
@@ -0,0 +1,67 @@
+using OrdinaryDiffEq, Lux, OptimizationOptimisers, Test, Statistics, Optimisers, NeuralPDE
+
+function fu(u, p, t)
+    [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
+end
+
+p = [1.5, 1.0, 3.0, 1.0]
+u0 = [1.0, 1.0]
+tspan = (0.0, 3.0)
+points1 = [rand() for i in 1:280]
+points2 = [rand() + 1 for i in 1:80]
+points3 = [rand() + 2 for i in 1:40]
+addedPoints = vcat(points1, points2, points3)
+
+saveat = 0.01
+maxiters = 30000
+
+prob_oop = ODEProblem{false}(fu, u0, tspan, p)
+true_sol = solve(prob_oop, Tsit5(), saveat = saveat)
+func = Lux.σ
+N = 12
+chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
+                  Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))
+
+opt = Optimisers.Adam(0.01)
+threshold = 0.2
+
+# Deliberately poor weights, samples, and dx, so that training fails without the added points
+weights = [0.3, 0.3, 0.4]
+samples = 3
+dx = 1.0
+
+# GridTraining without added points (difference between solutions should be high)
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
+sol = solve(prob_oop, alg, verbose = true, maxiters = maxiters, saveat = saveat)
+
+@test abs(mean(sol) - mean(true_sol)) > threshold
+
+# GridTraining with added points (difference between solutions should be low)
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.GridTraining(dx))
+sol = solve(prob_oop, alg, verbose = true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+
+@test abs(mean(sol) - mean(true_sol)) < threshold
+
+# WeightedIntervalTraining without added points (difference between solutions should be high)
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
+sol = solve(prob_oop, alg, verbose = true, maxiters = maxiters, saveat = saveat)
+
+@test abs(mean(sol) - mean(true_sol)) > threshold
+
+# WeightedIntervalTraining with added points (difference between solutions should be low)
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
+sol = solve(prob_oop, alg, verbose = true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+
+@test abs(mean(sol) - mean(true_sol)) < threshold
+
+# StochasticTraining without added points (difference between solutions should be high)
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
+sol = solve(prob_oop, alg, verbose = true, maxiters = maxiters, saveat = saveat)
+
+@test abs(mean(sol) - mean(true_sol)) > threshold
+
+# StochasticTraining with added points (difference between solutions should be low)
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.StochasticTraining(samples))
+sol = solve(prob_oop, alg, verbose = true, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
+
+@test abs(mean(sol) - mean(true_sol)) < threshold
diff --git a/test/runtests.jl b/test/runtests.jl
index 2a176cf87..afe7186a2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -25,6 +25,7 @@ end
 end
 if GROUP == "All" || GROUP == "NNODE"
     @time @safetestset "NNODE" begin include("NNODE_tests.jl") end
+    @time @safetestset "NNODE_tstops" begin include("NNODE_tstops_test.jl") end
 end

 if GROUP == "All" || GROUP == "NNPDE2"
@@ -60,4 +61,4 @@ end
     @safetestset "NNPDE_gpu" begin include("NNPDE_tests_gpu.jl") end
     @safetestset "NNPDE_gpu_Lux" begin include("NNPDE_tests_gpu_Lux.jl") end
 end
-end
+end
\ No newline at end of file
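Reviewer note: end to end, the new keyword is driven entirely through `solve`. A hedged usage sketch (the problem, network, and point distribution are illustrative, not taken from this PR):

```julia
using NeuralPDE, OrdinaryDiffEq, Lux, Optimisers

# u' = -u with u(0) = 1, whose solution exp(-t) changes fastest near t = 0
prob = ODEProblem{false}((u, p, t) -> -u, [1.0], (0.0, 3.0))
chain = Lux.Chain(Lux.Dense(1, 8, Lux.σ), Lux.Dense(8, 1))
alg = NeuralPDE.NNODE(chain, Optimisers.Adam(0.01),
                      strategy = NeuralPDE.GridTraining(1.0))

# Cluster extra training points near t = 0 and pass them via tstops
extra = 3.0 .* rand(100) .^ 2
sol = solve(prob, alg, maxiters = 2000, saveat = 0.01, tstops = extra)
```

With GridTraining at dx = 1.0 the solver would otherwise see only four collocation points, so the added points carry most of the training signal.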