feat: compatibility of NNODE with CUDA
sathvikbhagavan committed Jun 29, 2024
1 parent 3652d6d commit fadb032
Showing 2 changed files with 82 additions and 77 deletions.
5 changes: 4 additions & 1 deletion src/NeuralPDE.jl
@@ -32,7 +32,10 @@ using SciMLBase: @add_kwonly, parameterless_type
using UnPack: @unpack
import ChainRulesCore, Lux, ComponentArrays
using Lux: FromFluxAdaptor
using ChainRulesCore: @non_differentiable
using ChainRulesCore: @ignore_derivatives
using LuxDeviceUtils: LuxCUDADevice, LuxCPUDevice, cpu_device
using LuxCUDA: CuArray, CUDABackend
using KernelAbstractions: @kernel, @Const, @index

RuntimeGeneratedFunctions.init(@__MODULE__)

154 changes: 78 additions & 76 deletions src/ode_solve.jl
@@ -1,7 +1,7 @@
abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end

"""
NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...)
NNODE(chain, opt, init_params = nothing; autodiff = false, batch = true, additional_loss = nothing, kwargs...)
Algorithm for solving ordinary differential equations using a neural network. This is a specialization
of the physics-informed neural network which is used as a solver for a standard `ODEProblem`.
@@ -21,6 +21,7 @@ of the physics-informed neural network which is used as a solver for a standard
which thus uses the random initialization provided by the neural network library.
## Keyword Arguments
* `additional_loss`: A function `additional_loss(phi, θ)` where `phi` is the neural network trial solution and
`θ` are the weights of the neural network(s).
* `autodiff`: The switch between automatic and numerical differentiation for
@@ -71,7 +72,7 @@ is an accurate interpolation (up to the neural network training result). In addi
Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
"""
struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
struct NNODE{C, O, P, B, PE, K, D, AL <: Union{Nothing, Function},
S <: Union{Nothing, AbstractTrainingStrategy}
} <:
NeuralPDEAlgorithm
@@ -83,15 +84,33 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
strategy::S
param_estim::PE
additional_loss::AL
device::D
kwargs::K
end
function NNODE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
autodiff = false, batch = true, param_estim = false,
additional_loss = nothing, device = cpu_device(), kwargs...)
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
NNODE(chain, opt, init_params, autodiff, batch,
strategy, param_estim, additional_loss, kwargs)
strategy, param_estim, additional_loss, device, kwargs)
end
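
For context, constructing the solver with the new `device` keyword might look like the sketch below. This is hypothetical usage, not part of the diff; `gpu_device` comes from LuxDeviceUtils and falls back to the CPU when CUDA is not functional:

using Lux, LuxDeviceUtils
using OptimizationOptimisers: Adam

chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 1))
alg_cpu = NNODE(chain, Adam(0.01))  # device = cpu_device() by default
alg_gpu = NNODE(chain, Adam(0.01); device = LuxDeviceUtils.gpu_device())  # CUDA path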

@kernel function custom_broadcast!(f, du, @Const(out), @Const(p), @Const(t))
i = @index(Global, Linear)
@views @inbounds x = f(out[:, i], p, t[i])
du[:, i] .= x
end

gpu_broadcast = custom_broadcast!(CUDABackend())

function get_array_type(::LuxCUDADevice)
CuArray
end

function get_array_type(::LuxCPUDevice)
Array
end
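
The kernel above evaluates the out-of-place ODE right-hand side once per collocation point: thread `i` applies `f` to the `i`-th column of `out` and writes the result into the `i`-th column of `du`. A plain-Julia analogue of what it computes, illustrative only and assuming `out` is `d×N` and `t` has length `N`:

function cpu_reference!(f, du, out, p, t)
    # same semantics as custom_broadcast!, one column per collocation point
    for i in axes(out, 2)
        du[:, i] .= f(view(out, :, i), p, t[i])
    end
    du
end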

"""
@@ -100,53 +119,41 @@ end
Internal struct, used for representing the ODE solution as a neural network in a form that respects boundary conditions, i.e.
`phi(t) = u0 + (t - t0)*NN(t)`.
"""
mutable struct ODEPhi{C, T, U, S}
mutable struct ODEPhi{C, T, U, S, D}
chain::C
t0::T
u0::U
st::S
function ODEPhi(chain::Lux.AbstractExplicitLayer, t::Number, u0, st)
new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st)
end
device::D
function ODEPhi(chain::Lux.AbstractExplicitLayer, t0::Number, u0, st, device)
new{typeof(chain), typeof(t0), typeof(u0), typeof(st), typeof(device)}(
chain, t0, u0, st, device)
end
end

function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
function generate_phi_θ(
chain::Lux.AbstractExplicitLayer, t0, u0, init_params, device, p, param_estim)
θ, st = Lux.setup(Random.default_rng(), chain)
isnothing(init_params) && (init_params = θ)
ODEPhi(chain, t, u0, st), init_params
end

function (f::ODEPhi{C, T, U})(t::Number,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 + (t - f.t0) * first(y)
end

function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
# Batch via data as row vectors
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end

function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t .- f.t0) .* y
end
array_type = get_array_type(device)
init_params = if param_estim
ComponentArrays.ComponentArray(;
depvar = init_params, p = p)
else
ComponentArrays.ComponentArray(;
depvar = init_params)
end
u0_ = u0 isa Number ? u0 : array_type(u0)
ODEPhi(chain, t0, u0_, st, device), adapt(array_type, init_params)
end
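
In the rewritten `generate_phi_θ`, the trainable vector is a ComponentArray moved to the target device's array type; with `param_estim = true` it nests the ODE parameters next to the network weights. A layout sketch with hypothetical values:

using ComponentArrays
θ = ComponentArray(; depvar = (weight = rand(4), bias = rand(2)), p = [1.0, 2.0])
θ.depvar  # network parameters, consumed by the chain inside ODEPhi
θ.p       # ODE parameters, optimized jointly when param_estim = true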

function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U}
function (f::ODEPhi{C, T, U})(
t::AbstractVector, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
# Batch via data as row vectors
y, st = f.chain(
adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
@ignore_derivatives f.st = st
f.u0 .+ (t .- f.t0)' .* y
end
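
Since the trial solution is `u0 .+ (t .- t0)' .* NN(t)`, the initial condition holds exactly at `t0` regardless of the weights. A quick CPU sanity sketch using the internal helpers from this diff (hypothetical chain and values):

using Lux, ComponentArrays
chain = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 2))
phi, θ0 = generate_phi_θ(chain, 0.0, [1.0, 2.0], nothing, cpu_device(), nothing, false)
phi([0.0, 0.5], θ0)  # 2×2 output; the first column is exactly u0 = [1.0, 2.0]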

"""
@@ -190,34 +197,37 @@ Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
function inner_loss end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p, param_estim::Bool) where {C, T, U <: Number}
p_ = param_estim ? θ.p : p
sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p_, t))
p, param_estim::Bool) where {C, T, U}
array_type = get_array_type(phi.device)
p = param_estim ? θ.p : p
p = p isa SciMLBase.NullParameters ? p : array_type(p)
t = array_type([t])
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
out = phi(t, θ)
fs = rhs(phi.device, f, phi.u0, out, p, t)
sum(abs2, dxdtguess .- fs)
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p, param_estim::Bool) where {C, T, U <: Number}
p_ = param_estim ? θ.p : p
p, param_estim::Bool) where {C, T, U}
array_type = get_array_type(phi.device)
t = array_type(t)
p = param_estim ? θ.p : p
p = p isa SciMLBase.NullParameters ? p : array_type(p)
out = phi(t, θ)
fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
fs = rhs(phi.device, f, phi.u0, out, p, t)
dxdtguess = ode_dfdx(phi, t, θ, autodiff)
sum(abs2, dxdtguess .- fs) / length(t)
end
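
In both methods the quantity minimized is the mean squared PINN residual over the collocation points, (1/N) Σᵢ ‖ d/dt phi(tᵢ, θ) − f(phi(tᵢ, θ), p, tᵢ) ‖². A naive restatement of the batched loss for the CPU path (illustrative, not the implementation):

naive_loss(phi, f, ts, θ, p) = sum(abs2,
    ode_dfdx(phi, ts, θ, false) .-  # d/dt of the trial solution at each tᵢ
    reduce(hcat, [f(phi(ts, θ)[:, i], p, ts[i]) for i in eachindex(ts)])) / length(ts)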

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p, param_estim::Bool) where {C, T, U}
p_ = param_estim ? θ.p : p
sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t))
end
function rhs(::LuxCPUDevice, f, u0, out, p, t)
u0 isa Number ? reduce(hcat, [f(out[i], p, t[i]) for i in axes(out, 2)]) :
reduce(hcat, [f(out[:, i], p, t[i]) for i in axes(out, 2)])
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p, param_estim::Bool) where {C, T, U}
p_ = param_estim ? θ.p : p
out = Array(phi(t, θ))
arrt = Array(t)
fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, dxdtguess .- fs) / length(t)
end
function rhs(::LuxCUDADevice, f, u0, out, p, t)
du = similar(out)
gpu_broadcast(f, du, out, p, t; workgroupsize = 64, ndrange = 100)
end
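
Dispatching `rhs` on the device type selects the evaluation strategy: a column-wise comprehension on the CPU, the KernelAbstractions launch on CUDA. A CPU usage sketch with a hypothetical right-hand side:

f_ode(u, p, t) = -u           # hypothetical RHS
out = rand(2, 5); ts = rand(5)
rhs(LuxCPUDevice(), f_ode, zeros(2), out, nothing, ts)  # 2×5 matrix; column i == f_ode(out[:, i], nothing, ts[i])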

"""
@@ -323,8 +333,10 @@ struct NNODEInterpolation{T <: ODEPhi, T2}
phi::T
θ::T2
end
(f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)
(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)[idxs]
function (f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity)
vec(f.phi([t], f.θ))
end
(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = vec(f.phi([t], f.θ))[idxs]

function (f::NNODEInterpolation)(t::Vector, idxs::Nothing, ::Type{Val{0}}, p, continuity)
out = f.phi(t, f.θ)
@@ -358,36 +370,25 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
p = prob.p
t0 = tspan[1]
param_estim = alg.param_estim

#hidden layer
chain = alg.chain
opt = alg.opt
autodiff = alg.autodiff

#train points generation
init_params = alg.init_params
device = alg.device

!(chain isa Lux.AbstractExplicitLayer) &&
error("Only Lux.AbstractExplicitLayer neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
((eltype(eltype(init_params).types[1]) <: Complex ||
eltype(eltype(init_params).types[2]) <: Complex) &&
alg.strategy isa QuadratureTraining) &&
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params, device, p, param_estim)

(eltype(init_params) <: Complex &&
alg.strategy isa QuadratureTraining) &&
error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")

init_params = if alg.param_estim
ComponentArrays.ComponentArray(;
depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)
else
ComponentArrays.ComponentArray(;
depvar = ComponentArrays.ComponentArray(init_params))
end

isinplace(prob) &&
throw(error("The NNODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

try
phi(t0, init_params)
phi(get_array_type(device)([t0]), init_params)
catch err
if isa(err, DimensionMismatch)
throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
@@ -473,10 +474,11 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
ts = [tspan[1], tspan[2]]
end

u = phi(ts, res.u)
if u0 isa Number
u = [first(phi(t, res.u)) for t in ts]
u = vec(u)
else
u = [phi(t, res.u) for t in ts]
u = [u[:, i] for i in 1:size(u, 2)]
end

sol = SciMLBase.build_solution(prob, alg, ts, u;
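
Taken together, a hypothetical end-to-end run of the new CUDA path might look like the sketch below. This is illustrative only; it assumes a functional CUDA setup and that `ODEProblem`/`solve` are available via SciMLBase:

using NeuralPDE, Lux, LuxDeviceUtils
using SciMLBase: ODEProblem, solve
using OptimizationOptimisers: Adam

f(u, p, t) = -u  # hypothetical test problem: du/dt = -u, u(0) = [1, 1]
prob = ODEProblem(f, [1.0f0, 1.0f0], (0.0f0, 1.0f0))
chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 2))
alg = NNODE(chain, Adam(0.05); device = LuxDeviceUtils.gpu_device(),
    strategy = NeuralPDE.GridTraining(0.05f0))
sol = solve(prob, alg; maxiters = 2000)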
