diff --git a/Project.toml b/Project.toml index 201fd3dd9..a6f64865a 100644 --- a/Project.toml +++ b/Project.toml @@ -43,45 +43,48 @@ Adapt = "4" AdvancedHMC = "0.6.1" Aqua = "0.8" ArrayInterface = "7.9" -CUDA = "5.3" +CUDA = "5.3.2" ChainRulesCore = "1.24" -ComponentArrays = "0.15.14" +ComponentArrays = "0.15.16" Cubature = "1.5" DiffEqNoiseProcess = "5.20" Distributions = "0.25.107" DocStringExtensions = "0.9.3" DomainSets = "0.6, 0.7" -Flux = "0.14.11" +Flux = "0.14.17" ForwardDiff = "0.10.36" -Functors = "0.4.10" +Functors = "0.4.12" Integrals = "4.4" LineSearches = "7.2" -LinearAlgebra = "1" +LinearAlgebra = "1.10" LogDensityProblems = "2" -Lux = "0.5.58" +Lux = "1.0" LuxCUDA = "0.3.2" +LuxCore = "0.1.24" +LuxLib = "0.3.48" MCMCChains = "6" MethodOfLines = "0.11" ModelingToolkit = "9.9" MonteCarloMeasurements = "1.1" Optim = "1.7.8" -Optimization = "3.24" +Optimization = "3.25" OptimizationOptimJL = "0.2.1" OptimizationOptimisers = "0.2.1" -OrdinaryDiffEq = "6.74" -Pkg = "1" +OrdinaryDiffEq = "6.87" +Pkg = "1.10" +Preferences = "1.4.3" QuasiMonteCarlo = "0.3.2" Random = "1" Reexport = "1.2" RuntimeGeneratedFunctions = "0.5.12" SafeTestsets = "0.1" -SciMLBase = "2.28" +SciMLBase = "2.34" Statistics = "1.10" SymbolicUtils = "1.5, 2, 3" Symbolics = "5.27.1, 6" -Test = "1" +Test = "1.10" UnPack = "1" -Zygote = "0.6.69" +Zygote = "0.6.70" julia = "1.10" [extras] @@ -90,12 +93,15 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"] +test = ["Aqua", "CUDA", "Flux", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "Pkg", "Preferences", "SafeTestsets", "Test"] \ No newline at end of file diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl index e199498e3..03a620315 100644 --- a/src/pino_ode_solve.jl +++ b/src/pino_ode_solve.jl @@ -11,7 +11,7 @@ neural operator, which is used as a solver for a parametrized `ODEProblem`. ## Positional Arguments -* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`. +* `chain`: A neural network architecture, defined as a `Lux.AbstractLuxLayer` or `Flux.Chain`. `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)` * `opt`: The optimizer to train the neural network. * `bounds`: A dictionary containing the bounds for the parameters of the parametric ODE. @@ -51,7 +51,7 @@ function PINOODE(chain, strategy = nothing, additional_loss = nothing, kwargs...) - !(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain)) + !(chain isa Lux.AbstractLuxLayer) && (chain = Lux.transform(chain)) PINOODE(chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss, kwargs) end @@ -59,12 +59,12 @@ end struct PINOPhi{C, S} chain::C st::S - function PINOPhi(chain::Lux.AbstractExplicitLayer, st) + function PINOPhi(chain::Lux.AbstractLuxLayer, st) new{typeof(chain), typeof(st)}(chain, st) end end -function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer, init_params) +function generate_pino_phi_θ(chain::Lux.AbstractLuxLayer, init_params) θ, st = Lux.setup(Random.default_rng(), chain) init_params = isnothing(init_params) ? θ : init_params init_params = ComponentArrays.ComponentArray(init_params) @@ -72,7 +72,7 @@ function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer, init_params) end function (f::PINOPhi{C, T})( - x, θ) where {C <: Lux.AbstractExplicitLayer, T} + x, θ) where {C <: Lux.AbstractLuxLayer, T} y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st) y end @@ -171,8 +171,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, @unpack tspan, u0, f = prob @unpack chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss = alg - if !(chain isa Lux.AbstractExplicitLayer) - error("Only Lux.AbstractExplicitLayer neural networks are supported") + if !(chain isa Lux.AbstractLuxLayer) + error("Only Lux.AbstractLuxLayer neural networks are supported") if !(chain isa DeepONet) #|| chain isa FourierNeuralOperator) error("Only DeepONet and FourierNeuralOperator neural networks are supported with PINO ODE")