Skip to content

Commit

Permalink
add interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Jun 27, 2024
1 parent 2818f34 commit 8895693
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 23 deletions.
18 changes: 10 additions & 8 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function physics_loss(
p, t = x
f = prob.f
out = phi(x, θ)
if size(p)[1] == 1
if size(p,1) == 1
fs = f.(out, p, vec(t))
f_vec = vec(fs)
else
Expand All @@ -124,18 +124,18 @@ end

function get_trainset(strategy::GridTraining, bounds, number_of_parameters, tspan)
dt = strategy.dx
if size(bounds)[1] == 1
if size(bounds,1) == 1
bound = bounds[1]
p_ = range(start = bound[1], length = number_of_parameters, stop = bound[2])
p = collect(reshape(p_, 1, size(p_)[1]))
p = collect(reshape(p_, 1, size(p_,1)))
else
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2])
for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i)[1])) for p_i in p_]...)
p = vcat([collect(reshape(p_i, 1, size(p_i,1))) for p_i in p_]...)
end

t_ = collect(tspan[1]:dt:tspan[2])
t = reshape(t_, 1, size(t_)[1], 1)
t = reshape(t_, 1, size(t_,1), 1)
(p, t)
end

Expand All @@ -148,7 +148,7 @@ function generate_loss(
end

function get_trainset(strategy::StochasticTraining, bounds, number_of_parameters, tspan)
if size(bounds)[1] == 1
if size(bounds,1) == 1
bound = bounds[1]
p = (bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1]
else
Expand All @@ -173,8 +173,10 @@ struct PINOODEInterpolation{T <: PINOPhi, T2}
θ::T2
end

#TODO
# (f::NNODEInterpolation)(t, ...) = f.phi(t, f.θ)
(f::PINOODEInterpolation)(x) = f.phi(x, f.θ)

SciMLBase.interp_summary(::PINOODEInterpolation) = "Trained neural network interpolation"
SciMLBase.allowscomplex(::PINOODE) = true

function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
alg::PINOODE,
Expand Down
37 changes: 22 additions & 15 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,30 @@ using NeuralPDE

bounds = [(pi, 2pi)]
number_of_parameters = 50
# dt = (tspan[2] - tspan[1]) / 40
# strategy = GridTraining(dt)
strategy = StochasticTraining(40)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 5000)
sol.original.objective
# TODO intrepretation output with few mesh
sol = solve(prob, alg, verbose = false, maxiters = 2000)

ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = collect(reshape(p_, 1, size(p_)[1]))
t_ = collect(tspan[1]:dt:tspan[2])
t = collect(reshape(t_, 1, size(t_)[1], 1))
ground_solution = ground_analytic.(u0, p, t_)
predict_sol = sol.interp.phi((p,t), sol.interp.θ)
function get_trainset(bounds, tspan , number_of_parameters, dt)
p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = collect(reshape(p_, 1, size(p_, 1)))
t_ = collect(tspan[1]:dt:tspan[2])
t = collect(reshape(t_, 1, size(t_, 1), 1))
(p,t)
end
p,t = get_trainset(bounds, tspan, number_of_parameters, dt)

ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol.interp((p, t))

@test ground_solutionpredict_sol rtol=0.01

p, t = get_trainset(bounds, tspan, 100, 0.01)
ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol.interp((p, t))

@test ground_solutionpredict_sol rtol=0.1
@test ground_solutionpredict_sol rtol=0.01
end

Expand Down Expand Up @@ -76,7 +83,7 @@ end
(p^2 + 1)

p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = collect(reshape(p_, 1, size(p_)[1]))
p = collect(reshape(p_, 1, size(p_,1)))
ground_solution = ground_analytic_.(u0, p, vec(sol.t[2]))

@test ground_solutionsol.u rtol=0.01
Expand Down Expand Up @@ -163,7 +170,7 @@ end
Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast)))

u = rand(2, 50)
u = rand(2, 50, 1)
v = rand(1, 40, 1)
θ, st = Lux.setup(Random.default_rng(), deeponet)
c = deeponet((u, v), θ, st)[1]
Expand All @@ -178,7 +185,7 @@ end

ga = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t)
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i)[1])) for p_i in p_]...)
p = vcat([collect(reshape(p_i, 1, size(p_i,1))) for p_i in p_]...)
t = sol.t[2]
ground_solution = reduce(hcat,
[[ga(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
Expand Down

0 comments on commit 8895693

Please sign in to comment.