Changes made so far
hippyhippohops committed Sep 9, 2024
1 parent 2968094 commit b79e5b8
Showing 3 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
@@ -47,6 +47,7 @@ include("adaptive_losses.jl")
 include("ode_solve.jl")
 # include("rode_solve.jl")
 include("dae_solve.jl")
+include("refactored_solve.jl")
 include("transform_inf_integral.jl")
 include("discretize.jl")
 include("neural_adapter.jl")
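For context, `include` evaluates a file's source inside the enclosing module at load time, so this one-line addition is what makes the new `refactored_solve.jl` definitions part of `NeuralPDE`. A minimal sketch of the mechanism, with a hypothetical module and file name:

```julia
# Create a tiny source file, then splice it into a module with include().
write("helpers.jl", "helper() = 42")

module MyPkg
include("helpers.jl")   # helper() is now defined inside MyPkg
end

MyPkg.helper()          # => 42
```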
12 changes: 6 additions & 6 deletions src/refactored_solve.jl
@@ -232,7 +232,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
     if param_estim isa Bool
         p_ = param_estim ? θ.p : p
         sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t))
-    else if differential_vars isa AbstractVector
+    elseif differential_vars isa AbstractVector
         dphi = dfdx(phi, t, θ, autodiff, differential_vars)
         sum(abs2, f(dphi, phi(t, θ), p, t))
     end
@@ -248,7 +248,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
         fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
         dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
         sum(abs2, dxdtguess .- fs) / length(t)
-    else if differential_vars isa AbstractVector
+    elseif differential_vars isa AbstractVector
         out = Array(phi(t, θ))
         dphi = Array(dfdx(phi, t, θ, autodiff, differential_vars))
         arrt = Array(t)
@@ -271,7 +271,7 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
         function integrand(ts, θ)
             [abs2(inner_loss(phi, f, autodiff, t, θ, p; param_estim)) for t in ts]
         end
-    else if differential_vars isa AbstractVector
+    elseif differential_vars isa AbstractVector
         integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p; differential_vars))
 
         function integrand(ts, θ)
@@ -302,7 +302,7 @@ function generate_loss(
             end
         end
         return loss
-    else if differential_vars isa AbstractVector
+    elseif differential_vars isa AbstractVector
         function loss(θ, _)
             sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
         end
@@ -325,7 +325,7 @@ function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tsp
             end
         end
         return loss
-    else if differential_vars isa AbstractVector
+    elseif differential_vars isa AbstractVector
         function loss(θ, _)
             ts = adapt(parameterless_type(θ),
                 [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
@@ -367,7 +367,7 @@ function generate_loss(
             end
         end
         return loss
-    else if differential_vars isa AbstractVector
+    elseif differential_vars isa AbstractVector
         function loss(θ, _)
             sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
         end
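A note on the recurring `else if` → `elseif` fix above: Julia has no two-word `else if`. Writing `else if` opens a fresh nested `if` inside the `else` branch, which then needs its own `end`, so the original one-`end` chains would not parse; `elseif` continues the same conditional block. A minimal sketch of the dispatch pattern these loss functions use (`branch_demo` and the returned symbols are illustrative, not from the repository):

```julia
function branch_demo(param_estim, differential_vars)
    if param_estim isa Bool
        :ode_branch    # ODE-style loss, as in inner_loss above
    elseif differential_vars isa AbstractVector
        :dae_branch    # DAE branch: residual f(du, u, p, t) driven to zero
    else
        error("unsupported argument combination")
    end
end

branch_demo(true, nothing)           # => :ode_branch
branch_demo(nothing, [true, false])  # => :dae_branch
```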
4 changes: 2 additions & 2 deletions test/NNDAE_tests.jl
@@ -2,7 +2,7 @@ using Test, Flux
 using Random, NeuralPDE
 using OrdinaryDiffEq, Statistics
 using LineSearches
-import refactored_solve
+using refactored_solve
 import Lux, OptimizationOptimisers, OptimizationOptimJL
 
 
@@ -95,7 +95,7 @@ end
     alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1),
         strategy = WeightedIntervalTraining(weights, points); autodiff = false)
 
-    sol = solve(prob,
+    sol = refactored_solve(prob,
         alg, verbose = false, dt = 1 / 100.0,
         maxiters = 3000, abstol = 1e-10)
 
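On the `import refactored_solve` → `using refactored_solve` change in this test file: in Julia, `import M` binds only the module name, so members must be qualified as `M.f`, while `using M` additionally brings `M`'s exported names into scope unqualified, matching the unqualified `refactored_solve(prob, ...)` call in the hunk above. A minimal sketch with a hypothetical module:

```julia
module Demo
export greet
greet() = "hello"
end

import .Demo   # only the name `Demo` is in scope here
Demo.greet()   # works; a bare greet() would raise UndefVarError

using .Demo    # exported names now come in unqualified
greet()        # works
```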
