diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl index 03a620315..8d002e365 100644 --- a/src/pino_ode_solve.jl +++ b/src/pino_ode_solve.jl @@ -71,8 +71,11 @@ function generate_pino_phi_θ(chain::Lux.AbstractLuxLayer, init_params) PINOPhi(chain, st), init_params end -function (f::PINOPhi{C, T})( - x, θ) where {C <: Lux.AbstractLuxLayer, T} +function (f::PINOPhi{C, T})(x, θ) where {C <: Lux.AbstractLuxLayer, T} + # θ_ = ComponentArrays.getdata(θ) + # eltypeθ, typeθ = eltype(θ_), parameterless_type(θ_) + # t_ = convert.(eltypeθ, adapt(typeθ, t')) + # y, st = f.chain(t_, θ, f.st) y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st) y end @@ -86,6 +89,13 @@ function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C <: DeepONet, T} (phi(x_left, θ) .- phi(x_right, θ)) ./ sqrt(eps(eltype(t))) end +#TODO chain +function dfdx(phi::PINOPhi{C, T}, x::Array, + θ) where {C <: Lux.Chain, T} + ε = [zeros(eltype(x), size(x)[1] - 1)..., sqrt(eps(eltype(x)))] + (phi(x .+ ε, θ) - phi(x, θ)) ./ sqrt(eps(eltype(x))) +end + function physics_loss( phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {C <: DeepONet, T} p, t = x @@ -102,6 +112,23 @@ function physics_loss( sum(abs2, du .- f_vec) / norm end +function physics_loss( + phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where { + C <: Lux.Chain, T} + p, t = x + x_ = reduce(vcat, (p, t)) + f = prob.f + if size(p, 1) == 1 + f_vec = f.(out, p, t) + else + #TODO + # f_vec = reduce( vcat, [[f(out[i], p[i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)]) + end + du = dfdx(phi, x_, θ) + norm = prod(size(out)) + sum(abs2, du .- f_vec) / norm +end + function initial_condition_loss( phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where { C <: DeepONet, T} @@ -115,6 +142,19 @@ function initial_condition_loss( sum(abs2, u .- u0) / norm end +function initial_condition_loss( + phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where { + C <: Lux.Chain, T} + p, t = x + t0 = fill(prob.tspan[1], size(p)) + x0 = reduce(vcat, (p, t0)) + out = phi(x0, θ) + u = vec(out) + u0 = vec(fill(prob.u0, size(out))) + norm = prod(size(u0)) + sum(abs2, u .- u0) / norm +end + function get_trainset( strategy::GridTraining, chain::DeepONet, bounds, number_of_parameters, tspan) dt = strategy.dx @@ -134,6 +174,25 @@ function get_trainset( (p, t) end +function get_trainset(strategy::GridTraining, chain::Lux.Chain, bounds, + number_of_parameters, tspan) + dt = strategy.dx + p = collect([range(start = b[1], length = number_of_parameters, stop = b[2]) + for b in bounds]...) + t = collect(tspan[1]:dt:tspan[2]) + combinations = (collect(Iterators.product(p, t))) + N = size(p, 1) + M = size(t, 1) + x = zeros(2, N, M) + for i in 1:N + for j in 1:M + x[:, i, j] = [combinations[i, j]...] + end + end + p, t = x[1:(end - 1), :, :], x[[end], :, :] + (p, t) +end + function generate_loss( strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan) x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan) @@ -174,8 +233,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, if !(chain isa Lux.AbstractLuxLayer) error("Only Lux.AbstractLuxLayer neural networks are supported") - if !(chain isa DeepONet) #|| chain isa FourierNeuralOperator) - error("Only DeepONet and FourierNeuralOperator neural networks are supported with PINO ODE") + if !(chain isa DeepONet) || !(chain isa Chain) + error("Only DeepONet and Chain neural networks are supported with PINO ODE") end end @@ -192,6 +251,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, x = (u, v) phi(x, init_params) end + if chain isa Chain + in_dim = chain.layers.layer_1.in_dims + x = rand(Float32, in_dim, number_of_parameters) + phi(x, init_params) + end catch err if isa(err, DimensionMismatch) throw(DimensionMismatch("Dimensions of input data and chain should match")) @@ -233,7 +297,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, res = solve(optprob, opt; callback, maxiters, alg.kwargs...) x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan) - u = phi(x, res.u) + if chain isa DeepONet + u = phi(x, res.u) + elseif chain isa Chain + u = phi(reduce(vcat, x), res.u) + end sol = SciMLBase.build_solution(prob, alg, x, u; k = res, dense = true, diff --git a/test/PINO_ode_tests.jl b/test/PINO_ode_tests.jl index 57110a199..28dd0a476 100644 --- a/test/PINO_ode_tests.jl +++ b/test/PINO_ode_tests.jl @@ -5,7 +5,7 @@ using Statistics, Random using NeuralOperators using NeuralPDE -function get_trainset(bounds, tspan, number_of_parameters, dt) +function get_trainset(chain::DeepONet, bounds, number_of_parameters, tspan, dt) p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2]) p = collect(reshape(p_, 1, size(p_, 1))) t_ = collect(tspan[1]:dt:tspan[2]) @@ -13,6 +13,54 @@ function get_trainset(bounds, tspan, number_of_parameters, dt) (p, t) end +function get_trainset(chain::Lux.Chain, bounds, number_of_parameters, tspan, dt) + dt = strategy.dx + p = collect([range(start = b[1], length = number_of_parameters, stop = b[2]) + for b in bounds]...) + t = collect(tspan[1]:dt:tspan[2]) + combinations = (collect(Iterators.product(p, t))) + N = size(p, 1) + M = size(t, 1) + x = zeros(2, N, M) + for i in 1:N + for j in 1:M + x[:, i, j] = [combinations[i, j]...] + end + end + p, t = x[1:(end - 1), :, :], x[[end], :, :] + (p, t) +end + +#Test with Chain +@testset "Example du = cos(p * t)" begin + equation = (u, p, t) -> cos(p * t) + tspan = (0.0f0, 1.0f0) + u0 = 1.0f0 + prob = ODEProblem(equation, u0, tspan) + chain = Chain(Dense(2 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 1)) + x = rand(1, 50) + θ, st = Lux.setup(Random.default_rng(), chain) + b = chain(x, θ, st)[1] + + bounds = [(pi, 2pi)] + number_of_parameters = 50 + strategy = GridTraining(0.1f0) + opt = OptimizationOptimisers.Adam(0.01) + alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy) + sol = solve(prob, alg, verbose = false, maxiters = 5000) + ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p) + dt = 0.025f0 + p, t = get_trainset(chain, bounds, number_of_parameters, tspan, dt) + ground_solution = ground_analytic.(u0, p, t) + predict_sol = sol.interp(reduce(vcat, (p, t))) + @test ground_solution≈predict_sol rtol=0.07 + p, t = get_trainset(chain, bounds, 100, tspan, 0.01) + ground_solution = ground_analytic.(u0, p, t) + predict_sol = sol.interp(reduce(vcat, (p, t))) + @test ground_solution≈predict_sol rtol=0.07 +end + +#Test with DeepONet @testset "Example du = cos(p * t)" begin equation = (u, p, t) -> cos(p * t) tspan = (0.0f0, 1.0f0) @@ -42,11 +90,11 @@ end sol = solve(prob, alg, verbose = false, maxiters = 2000) ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p) dt = 0.025f0 - p, t = get_trainset(bounds, tspan, number_of_parameters, dt) + p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) ground_solution = ground_analytic.(u0, p, vec(t)) predict_sol = sol.interp((p, t)) @test ground_solution≈predict_sol rtol=0.05 - p, t = get_trainset(bounds, tspan, 100, 0.01) + p, t = get_trainset(deeponet, bounds, tspan, 100, 0.01) ground_solution = ground_analytic.(u0, p, vec(t)) predict_sol = sol.interp((p, t)) @test ground_solution≈predict_sol rtol=0.05 @@ -73,7 +121,7 @@ end #if u0 == 1 ground_analytic_(u0, p, t) = (p * sin(p * t) - cos(p * t) + (p^2 + 2) * exp(t)) / (p^2 + 1) - p, t = get_trainset(bounds, tspan, number_of_parameters, dt) + p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) ground_solution = ground_analytic_.(u0, p, vec(t)) predict_sol = sol.interp((p, t)) @test ground_solution≈predict_sol rtol=0.5 @@ -106,7 +154,7 @@ end v = rand(1, 40, 1) θ, st = Lux.setup(Random.default_rng(), deeponet) c = deeponet((u, v), θ, st)[1] - p, t = get_trainset(bounds, tspan, number_of_parameters, dt) + p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) data, tuple_ = get_data() function additional_loss_(phi, θ) u = phi(tuple_, θ) @@ -118,7 +166,7 @@ end additional_loss = additional_loss_) sol = solve(prob, alg, verbose = false, maxiters = 2000) - p, t = get_trainset(bounds, tspan, number_of_parameters, dt) + p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) ground_solution = ground_analytic.(u0, p, vec(t)) predict_sol = sol.interp((p, t)) @test ground_solution≈predict_sol rtol=0.05 @@ -150,7 +198,7 @@ end alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy) sol = solve(prob, alg, verbose = false, maxiters = 3000) - function get_trainset(bounds, tspan, number_of_parameters, dt) + function get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds] p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...) @@ -165,12 +213,12 @@ end [[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)]) end - (p, t) = get_trainset(bounds, tspan, 50, 0.025f0) + (p, t) = get_trainset(deeponet, bounds, 50, tspan, 0.025f0) ground_solution_ = ground_solution_f(p, t) predict = sol.interp((p, t)) @test ground_solution_≈predict rtol=0.05 - p, t = get_trainset(bounds, tspan, 100, 0.01f0) + p, t = get_trainset(deeponet, bounds, 100, tspan, 0.01f0) ground_solution_ = ground_solution_f(p, t) predict = sol.interp((p, t)) @test ground_solution_≈predict rtol=0.05