Skip to content

Commit

Permalink
support Chain
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Sep 24, 2024
1 parent 6813c5d commit 786035f
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 14 deletions.
78 changes: 73 additions & 5 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ function generate_pino_phi_θ(chain::Lux.AbstractLuxLayer, init_params)
PINOPhi(chain, st), init_params
end

function (f::PINOPhi{C, T})(
x, θ) where {C <: Lux.AbstractLuxLayer, T}
function (f::PINOPhi{C, T})(x, θ) where {C <: Lux.AbstractLuxLayer, T}
# θ_ = ComponentArrays.getdata(θ)
# eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_)
# t_ = convert.(eltypeθ, adapt(typeθ, t'))
# y, st = f.chain(t_, θ, f.st)
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st)
y
end
Expand All @@ -86,6 +89,13 @@ function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C <: DeepONet, T}
(phi(x_left, θ) .- phi(x_right, θ)) ./ sqrt(eps(eltype(t)))
end

#TODO chain
function dfdx(phi::PINOPhi{C, T}, x::Array,
θ) where {C <: Lux.Chain, T}
ε = [zeros(eltype(x), size(x)[1] - 1)..., sqrt(eps(eltype(x)))]
(phi(x .+ ε, θ) - phi(x, θ)) ./ sqrt(eps(eltype(x)))
end

function physics_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {C <: DeepONet, T}
p, t = x
Expand All @@ -102,6 +112,23 @@ function physics_loss(
sum(abs2, du .- f_vec) / norm
end

function physics_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {
C <: Lux.Chain, T}
p, t = x
x_ = reduce(vcat, (p, t))
f = prob.f
if size(p, 1) == 1
f_vec = f.(out, p, t)
else
#TODO
# f_vec = reduce( vcat, [[f(out[i], p[i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
end
du = dfdx(phi, x_, θ)
norm = prod(size(out))
sum(abs2, du .- f_vec) / norm
end

function initial_condition_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {
C <: DeepONet, T}
Expand All @@ -115,6 +142,19 @@ function initial_condition_loss(
sum(abs2, u .- u0) / norm
end

function initial_condition_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {
C <: Lux.Chain, T}
p, t = x
t0 = fill(prob.tspan[1], size(p))
x0 = reduce(vcat, (p, t0))
out = phi(x0, θ)
u = vec(out)
u0 = vec(fill(prob.u0, size(out)))
norm = prod(size(u0))
sum(abs2, u .- u0) / norm
end

function get_trainset(
strategy::GridTraining, chain::DeepONet, bounds, number_of_parameters, tspan)
dt = strategy.dx
Expand All @@ -134,6 +174,25 @@ function get_trainset(
(p, t)
end

function get_trainset(strategy::GridTraining, chain::Lux.Chain, bounds,
number_of_parameters, tspan)
dt = strategy.dx
p = collect([range(start = b[1], length = number_of_parameters, stop = b[2])
for b in bounds]...)
t = collect(tspan[1]:dt:tspan[2])
combinations = (collect(Iterators.product(p, t)))
N = size(p, 1)
M = size(t, 1)
x = zeros(2, N, M)
for i in 1:N
for j in 1:M
x[:, i, j] = [combinations[i, j]...]
end
end
p, t = x[1:(end - 1), :, :], x[[end], :, :]
(p, t)
end

function generate_loss(
strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan)
x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan)
Expand Down Expand Up @@ -174,8 +233,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
if !(chain isa Lux.AbstractLuxLayer)
error("Only Lux.AbstractLuxLayer neural networks are supported")

if !(chain isa DeepONet) #|| chain isa FourierNeuralOperator)
error("Only DeepONet and FourierNeuralOperator neural networks are supported with PINO ODE")
if !(chain isa DeepONet) || !(chain isa Chain)
error("Only DeepONet and Chain neural networks are supported with PINO ODE")
end
end

Expand All @@ -192,6 +251,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
x = (u, v)
phi(x, init_params)
end
if chain isa Chain
in_dim = chain.layers.layer_1.in_dims
x = rand(Float32, in_dim, number_of_parameters)
phi(x, init_params)
end
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of input data and chain should match"))
Expand Down Expand Up @@ -233,7 +297,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan)
u = phi(x, res.u)
if chain isa DeepONet
u = phi(x, res.u)
elseif chain isa Chain
u = phi(reduce(vcat, x), res.u)
end

sol = SciMLBase.build_solution(prob, alg, x, u;
k = res, dense = true,
Expand Down
66 changes: 57 additions & 9 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,62 @@ using Statistics, Random
using NeuralOperators
using NeuralPDE

function get_trainset(bounds, tspan, number_of_parameters, dt)
function get_trainset(chain::DeepONet, bounds, number_of_parameters, tspan, dt)
p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = collect(reshape(p_, 1, size(p_, 1)))
t_ = collect(tspan[1]:dt:tspan[2])
t = collect(reshape(t_, 1, size(t_, 1), 1))
(p, t)
end

function get_trainset(chain::Lux.Chain, bounds, number_of_parameters, tspan, dt)
dt = strategy.dx
p = collect([range(start = b[1], length = number_of_parameters, stop = b[2])
for b in bounds]...)
t = collect(tspan[1]:dt:tspan[2])
combinations = (collect(Iterators.product(p, t)))
N = size(p, 1)
M = size(t, 1)
x = zeros(2, N, M)
for i in 1:N
for j in 1:M
x[:, i, j] = [combinations[i, j]...]
end
end
p, t = x[1:(end - 1), :, :], x[[end], :, :]
(p, t)
end

#Test with Chain
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)
chain = Chain(Dense(2 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 1))
x = rand(1, 50)
θ, st = Lux.setup(Random.default_rng(), chain)
b = chain(x, θ, st)[1]

bounds = [(pi, 2pi)]
number_of_parameters = 50
strategy = GridTraining(0.1f0)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 5000)
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
dt = 0.025f0
p, t = get_trainset(chain, bounds, number_of_parameters, tspan, dt)
ground_solution = ground_analytic.(u0, p, t)
predict_sol = sol.interp(reduce(vcat, (p, t)))
@test ground_solutionpredict_sol rtol=0.07
p, t = get_trainset(chain, bounds, 100, tspan, 0.01)
ground_solution = ground_analytic.(u0, p, t)
predict_sol = sol.interp(reduce(vcat, (p, t)))
@test ground_solutionpredict_sol rtol=0.07
end

#Test with DeepONet
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
Expand Down Expand Up @@ -42,11 +90,11 @@ end
sol = solve(prob, alg, verbose = false, maxiters = 2000)
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
dt = 0.025f0
p, t = get_trainset(bounds, tspan, number_of_parameters, dt)
p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol.interp((p, t))
@test ground_solutionpredict_sol rtol=0.05
p, t = get_trainset(bounds, tspan, 100, 0.01)
p, t = get_trainset(deeponet, bounds, tspan, 100, 0.01)
ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol.interp((p, t))
@test ground_solutionpredict_sol rtol=0.05
Expand All @@ -73,7 +121,7 @@ end
#if u0 == 1
ground_analytic_(u0, p, t) = (p * sin(p * t) - cos(p * t) + (p^2 + 2) * exp(t)) /
(p^2 + 1)
p, t = get_trainset(bounds, tspan, number_of_parameters, dt)
p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
ground_solution = ground_analytic_.(u0, p, vec(t))
predict_sol = sol.interp((p, t))
@test ground_solutionpredict_sol rtol=0.5
Expand Down Expand Up @@ -106,7 +154,7 @@ end
v = rand(1, 40, 1)
θ, st = Lux.setup(Random.default_rng(), deeponet)
c = deeponet((u, v), θ, st)[1]
p, t = get_trainset(bounds, tspan, number_of_parameters, dt)
p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
data, tuple_ = get_data()
function additional_loss_(phi, θ)
u = phi(tuple_, θ)
Expand All @@ -118,7 +166,7 @@ end
additional_loss = additional_loss_)
sol = solve(prob, alg, verbose = false, maxiters = 2000)

p, t = get_trainset(bounds, tspan, number_of_parameters, dt)
p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
ground_solution = ground_analytic.(u0, p, vec(t))
predict_sol = sol.interp((p, t))
@test ground_solutionpredict_sol rtol=0.05
Expand Down Expand Up @@ -150,7 +198,7 @@ end
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 3000)

function get_trainset(bounds, tspan, number_of_parameters, dt)
function get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2])
for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...)
Expand All @@ -165,12 +213,12 @@ end
[[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
end

(p, t) = get_trainset(bounds, tspan, 50, 0.025f0)
(p, t) = get_trainset(deeponet, bounds, 50, tspan, 0.025f0)
ground_solution_ = ground_solution_f(p, t)
predict = sol.interp((p, t))
@test ground_solution_predict rtol=0.05

p, t = get_trainset(bounds, tspan, 100, 0.01f0)
p, t = get_trainset(deeponet, bounds, 100, tspan, 0.01f0)
ground_solution_ = ground_solution_f(p, t)
predict = sol.interp((p, t))
@test ground_solution_predict rtol=0.05
Expand Down

0 comments on commit 786035f

Please sign in to comment.