diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index a2ffc2370..411a26dd1 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -47,6 +47,7 @@ include("adaptive_losses.jl") include("ode_solve.jl") # include("rode_solve.jl") include("dae_solve.jl") +include("refactored_solve.jl") include("transform_inf_integral.jl") include("discretize.jl") include("neural_adapter.jl") diff --git a/src/refactored_solve.jl b/src/refactored_solve.jl index f24603ca1..a0f70499a 100644 --- a/src/refactored_solve.jl +++ b/src/refactored_solve.jl @@ -232,7 +232,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ, if param_estim isa Bool p_ = param_estim ? θ.p : p sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t)) - else if differential_vars isa AbstractVector + elseif differential_vars isa AbstractVector dphi = dfdx(phi, t, θ, autodiff,differential_vars) sum(abs2, f(dphi, phi(t, θ), p, t)) end @@ -248,7 +248,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)]) dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff)) sum(abs2, dxdtguess .- fs) / length(t) - else if differential_vars isa AbstractVector + elseif differential_vars isa AbstractVector out = Array(phi(t, θ)) dphi = Array(dfdx(phi, t, θ, autodiff, differential_vars)) arrt = Array(t) @@ -271,7 +271,7 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp function integrand(ts, θ) [abs2(inner_loss(phi, f, autodiff, t, θ, p; param_estim)) for t in ts] end - else if differential_vars isa AbstractVector + elseif differential_vars isa AbstractVector integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p; differential_vars)) function integrand(ts, θ) @@ -302,7 +302,7 @@ function generate_loss( end end return loss - else if differential_vars isa AbstractVector + elseif differential_vars isa AbstractVector function loss(θ, _) sum(abs2, inner_loss(phi, f, 
autodiff, ts, θ, p, differential_vars)) end @@ -325,7 +325,7 @@ function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tsp end end return loss - else if differential_vars isa AbstractVector + elseif differential_vars isa AbstractVector function loss(θ, _) ts = adapt(parameterless_type(θ), [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)]) @@ -367,7 +367,7 @@ function generate_loss( end end return loss - else if differential_vars isa AbstractVector + elseif differential_vars isa AbstractVector function loss(θ, _) sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars)) end diff --git a/test/NNDAE_tests.jl b/test/NNDAE_tests.jl index d4804aeaa..c8e1c956a 100644 --- a/test/NNDAE_tests.jl +++ b/test/NNDAE_tests.jl @@ -2,7 +2,7 @@ using Test, Flux using Random, NeuralPDE using OrdinaryDiffEq, Statistics using LineSearches -import refactored_solve +# NOTE(review): refactored_solve.jl is include()d into the NeuralPDE module, so it is not a loadable package; its definitions are already available via `using NeuralPDE` above import Lux, OptimizationOptimisers, OptimizationOptimJL @@ -95,7 +95,7 @@ end alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1), strategy = WeightedIntervalTraining(weights, points); autodiff = false) - sol = solve(prob, + sol = solve(prob, alg, verbose = false, dt = 1 / 100.0, maxiters = 3000, abstol = 1e-10)