lux v1.0
KirillZubov committed Sep 19, 2024
1 parent 2366840 commit b1209ab
Showing 2 changed files with 26 additions and 20 deletions.
Project.toml: 32 changes (19 additions & 13 deletions)
@@ -43,45 +43,48 @@ Adapt = "4"
 AdvancedHMC = "0.6.1"
 Aqua = "0.8"
 ArrayInterface = "7.9"
-CUDA = "5.3"
+CUDA = "5.3.2"
 ChainRulesCore = "1.24"
-ComponentArrays = "0.15.14"
+ComponentArrays = "0.15.16"
 Cubature = "1.5"
 DiffEqNoiseProcess = "5.20"
 Distributions = "0.25.107"
 DocStringExtensions = "0.9.3"
 DomainSets = "0.6, 0.7"
-Flux = "0.14.11"
+Flux = "0.14.17"
 ForwardDiff = "0.10.36"
-Functors = "0.4.10"
+Functors = "0.4.12"
 Integrals = "4.4"
 LineSearches = "7.2"
-LinearAlgebra = "1"
+LinearAlgebra = "1.10"
 LogDensityProblems = "2"
-Lux = "0.5.58"
+Lux = "1.0"
 LuxCUDA = "0.3.2"
+LuxCore = "0.1.24"
+LuxLib = "0.3.48"
 MCMCChains = "6"
 MethodOfLines = "0.11"
 ModelingToolkit = "9.9"
 MonteCarloMeasurements = "1.1"
 Optim = "1.7.8"
-Optimization = "3.24"
+Optimization = "3.25"
 OptimizationOptimJL = "0.2.1"
 OptimizationOptimisers = "0.2.1"
-OrdinaryDiffEq = "6.74"
-Pkg = "1"
+OrdinaryDiffEq = "6.87"
+Pkg = "1.10"
+Preferences = "1.4.3"
 QuasiMonteCarlo = "0.3.2"
 Random = "1"
 Reexport = "1.2"
 RuntimeGeneratedFunctions = "0.5.12"
 SafeTestsets = "0.1"
-SciMLBase = "2.28"
+SciMLBase = "2.34"
 Statistics = "1.10"
 SymbolicUtils = "1.5, 2, 3"
 Symbolics = "5.27.1, 6"
-Test = "1"
+Test = "1.10"
 UnPack = "1"
-Zygote = "0.6.69"
+Zygote = "0.6.70"
 julia = "1.10"
 
 [extras]
@@ -90,12 +93,15 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
+LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
+test = ["Aqua", "CUDA", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "Pkg", "Preferences", "SafeTestsets", "Test"]
src/pino_ode_solve.jl: 14 changes (7 additions & 7 deletions)
@@ -11,7 +11,7 @@ neural operator, which is used as a solver for a parametrized `ODEProblem`.
 ## Positional Arguments
 
-* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
+* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer` or `Flux.Chain`.
   `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`
 * `opt`: The optimizer to train the neural network.
 * `bounds`: A dictionary containing the bounds for the parameters of the parametric ODE.
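The `adapt(FromFluxAdaptor(false, false), chain)` call mentioned in the docstring is Lux's Flux-to-Lux conversion path and is unchanged by this commit. A minimal sketch of that conversion, using a hypothetical two-layer chain that is not from this repository:

```julia
using Adapt, Flux, Lux

flux_chain = Flux.Chain(Flux.Dense(1 => 16, tanh), Flux.Dense(16 => 1))

# With (preserve_ps_st = false, force_preserve = false) the Flux parameters are
# not carried over; fresh ones are created later by Lux.setup.
lux_chain = adapt(FromFluxAdaptor(false, false), flux_chain)

lux_chain isa Lux.AbstractLuxLayer   # true on Lux 1.0
```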
@@ -51,28 +51,28 @@ function PINOODE(chain,
         strategy = nothing,
         additional_loss = nothing,
         kwargs...)
-    !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
+    !(chain isa Lux.AbstractLuxLayer) && (chain = Lux.transform(chain))
     PINOODE(chain, opt, bounds, number_of_parameters,
         init_params, strategy, additional_loss, kwargs)
 end
 
 struct PINOPhi{C, S}
     chain::C
     st::S
-    function PINOPhi(chain::Lux.AbstractExplicitLayer, st)
+    function PINOPhi(chain::Lux.AbstractLuxLayer, st)
         new{typeof(chain), typeof(st)}(chain, st)
     end
 end
 
-function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer, init_params)
+function generate_pino_phi_θ(chain::Lux.AbstractLuxLayer, init_params)
     θ, st = Lux.setup(Random.default_rng(), chain)
     init_params = isnothing(init_params) ? θ : init_params
     init_params = ComponentArrays.ComponentArray(init_params)
     PINOPhi(chain, st), init_params
 end
 
 function (f::PINOPhi{C, T})(
-        x, θ) where {C <: Lux.AbstractExplicitLayer, T}
+        x, θ) where {C <: Lux.AbstractLuxLayer, T}
     y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st)
     y
 end
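For reference, the `Lux.setup` / `ComponentArray` pattern used by `generate_pino_phi_θ` and the `PINOPhi` call overload works the same on Lux 1.0; only the supertype in the signatures changes. A standalone sketch with a hypothetical chain:

```julia
using Lux, ComponentArrays, Random

chain = Lux.Chain(Lux.Dense(1 => 16, tanh), Lux.Dense(16 => 1))

θ, st = Lux.setup(Random.default_rng(), chain)   # initial parameters and state
θ = ComponentArrays.ComponentArray(θ)            # flat, optimizer-friendly parameters

x = reshape([0.5f0], 1, 1)                       # 1 feature × 1 sample
y, _ = chain(x, θ, st)                           # stateless call: (x, ps, st) -> (y, st)
```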
@@ -171,8 +171,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
     @unpack tspan, u0, f = prob
     @unpack chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss = alg
 
-    if !(chain isa Lux.AbstractExplicitLayer)
-        error("Only Lux.AbstractExplicitLayer neural networks are supported")
+    if !(chain isa Lux.AbstractLuxLayer)
+        error("Only Lux.AbstractLuxLayer neural networks are supported")
 
     if !(chain isa DeepONet) #|| chain isa FourierNeuralOperator)
         error("Only DeepONet and FourierNeuralOperator neural networks are supported with PINO ODE")
