From 043d429aca1c6985ee7a45c81380ef2f4855d0de Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 12 Apr 2023 17:46:50 +0100 Subject: [PATCH 01/40] begin symbolics overhaul --- Project.toml | 3 +- src/NeuralPDE.jl | 4 +- src/discretize.jl | 142 ++------ src/eq_data.jl | 54 +++ src/pinn_types.jl | 668 +++++++++++++++++++++----------------- src/symbolic_utilities.jl | 125 ++----- 6 files changed, 479 insertions(+), 517 deletions(-) create mode 100644 src/eq_data.jl diff --git a/Project.toml b/Project.toml index 50f95c9cab..a4be001957 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +PDEBase = "a7812802-0625-4b9e-961c-d332478797e5" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -79,4 +80,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"] \ No newline at end of file +test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"] diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index f61544274f..42f2088436 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -15,6 +15,7 @@ using Integrals, IntegralsCubature using QuasiMonteCarlo using RuntimeGeneratedFunctions using SciMLBase +using PDEBase using Statistics using ArrayInterface import Optim @@ -32,9 +33,8 @@ import RecursiveArrayTools import ChainRulesCore, Flux, Lux, ComponentArrays import ChainRulesCore: @non_differentiable -RuntimeGeneratedFunctions.init(@__MODULE__) -abstract type AbstractPINN end +RuntimeGeneratedFunctions.init(@__MODULE__) abstract type AbstractTrainingStrategy end diff --git a/src/discretize.jl b/src/discretize.jl index 4308a79b4e..13bade9828 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -49,7 +49,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; dict_transformation_vars = nothing, transformation_vars = nothing, integrating_depvars = pinnrep.depvars) - @unpack indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input, + @unpack v, eqdata, phi, derivative, integral, multioutput, init_params, strategy, eq_params, param_estim, default_p = pinnrep @@ -287,12 +287,11 @@ function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Arra return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) end -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, - strategy::QuadratureTraining) - dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) - dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) - - pde_args = get_argument(eqs, dict_indvars, dict_depvars) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining) + dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) + dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) + #! 
Fix this to work with a var_eq mapping + pde_args = get_argument(eqs, v) pde_lower_bounds = map(pde_args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -342,10 +341,9 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str end function get_numeric_integral(pinnrep::PINNRepresentation) - @unpack strategy, indvars, depvars, multioutput, derivative, - depvars, indvars, dict_indvars, dict_depvars = pinnrep + @unpack strategy, multioutput, derivative, varmap = pinnrep - integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, indvars = indvars, depvars = depvars, dict_indvars = dict_indvars, dict_depvars = dict_depvars) -> begin + integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, varmap=varmap) -> begin function integration_(cord, lb, ub, θ) cord_ = cord function integrand_(x, p) @@ -416,92 +414,19 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, additional_loss = discretization.additional_loss adaloss = discretization.adaptive_loss - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(pde_system.indvars, - pde_system.depvars) + v = VariableMap(pde_system, discretization) - multioutput = discretization.multioutput - init_params = discretization.init_params + # Find the derivative orders in the bcs + bcorders = Dict(map(x -> x => d_orders(x, pdesys.bcs), all_ivs(v))) + # Create a map of each variable to their boundary conditions including initial conditions + boundarymap = parse_bcs(pdesys.bcs, v, bcorders) - if init_params === nothing - # Use the initialization of the neural network framework - # But for Lux, default to Float64 - # For Flux, default to the types matching the values in the neural network - # This is done because Float64 is almost always better for these applications - # But with Flux there's already a chosen type from the user - - if chain isa AbstractArray - if chain[1] isa Flux.Chain - init_params = map(chain) do x - _x = Flux.destructure(x)[1] - end - else - x = map(chain) do x - _x = ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), - x)) - Float64.(_x) # No ComponentArray GPU support - end - names = ntuple(i -> depvars[i], length(chain)) - init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i - for i in x)) - end - else - if chain isa Flux.Chain - init_params = Flux.destructure(chain)[1] - init_params = init_params isa Array ? Float64.(init_params) : - init_params - else - init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), - chain))) - end - end - else - init_params = init_params - end - - if (discretization.phi isa Vector && discretization.phi[1].f isa Optimisers.Restructure) || - (!(discretization.phi isa Vector) && discretization.phi.f isa Optimisers.Restructure) - # Flux.Chain - flat_init_params = multioutput ? reduce(vcat, init_params) : init_params - flat_init_params = param_estim == false ? 
flat_init_params : - vcat(flat_init_params, - adapt(typeof(flat_init_params), default_p)) - else - flat_init_params = if init_params isa ComponentArrays.ComponentArray - init_params - elseif multioutput - @assert length(init_params) == length(depvars) - names = ntuple(i -> depvars[i], length(init_params)) - x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params)) - else - ComponentArrays.ComponentArray(init_params) - end - flat_init_params = if param_estim == false && multioutput - ComponentArrays.ComponentArray(; depvar = flat_init_params) - elseif param_estim == false && !multioutput - flat_init_params - else - ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p) - end - end - - eltypeθ = eltype(flat_init_params) - - if adaloss === nothing - adaloss = NonAdaptiveLoss{eltypeθ}() - end + eqdata = EquationData(pdesys, v) + multioutput = discretization.multioutput + init_params = discretization.init_params phi = discretization.phi - if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer) - for ϕ in phi - ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), - ϕ.st) - end - elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer) - phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), - phi.st) - end - derivative = discretization.derivative strategy = discretization.strategy @@ -510,32 +435,11 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, iteration = discretization.iteration self_increment = discretization.self_increment - if !(eqs isa Array) - eqs = [eqs] - end - - pde_indvars = if strategy isa QuadratureTraining - get_argument(eqs, dict_indvars, dict_depvars) - else - get_variables(eqs, dict_indvars, dict_depvars) - end - - bc_indvars = if strategy isa QuadratureTraining - get_argument(bcs, dict_indvars, dict_depvars) - else - get_variables(bcs, dict_indvars, dict_depvars) - end - - pde_integration_vars = get_integration_variables(eqs, dict_indvars, dict_depvars) - bc_integration_vars = get_integration_variables(bcs, dict_indvars, dict_depvars) - pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, - param_estim, additional_loss, adaloss, depvars, indvars, - dict_indvars, dict_depvars, dict_depvar_input, logger, + param_estim, additional_loss, adaloss, v, logger, multioutput, iteration, init_params, flat_init_params, phi, derivative, - strategy, pde_indvars, bc_indvars, pde_integration_vars, - bc_integration_vars, nothing, nothing, nothing, nothing) + strategy, eqdata, nothing, nothing, nothing, nothing) integral = get_numeric_integral(pinnrep) @@ -553,15 +457,9 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions - datafree_pde_loss_functions = [build_loss_function(pinnrep, eq, pde_indvar) - for (eq, pde_indvar, integration_indvar) in zip(eqs, - pde_indvars, - pde_integration_vars)] + datafree_pde_loss_functions = [build_loss_function(pinnrep, eq) for eq in eqs] - datafree_bc_loss_functions = [build_loss_function(pinnrep, bc, bc_indvar) - for (bc, bc_indvar, integration_indvar) in zip(bcs, - bc_indvars, - bc_integration_vars)] + datafree_bc_loss_functions = [build_loss_function(pinnrep, bc) for bc in bcs] pde_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep, strategy, diff --git a/src/eq_data.jl b/src/eq_data.jl new file mode 100644 index 0000000000..5475184d1e --- /dev/null 
+++ b/src/eq_data.jl @@ -0,0 +1,54 @@ +struct EquationData <: PDEBase.AbstractVarEqMapping + depvarmap + indvarmap + pde_indvars + bc_indvars + argmap +end + +function EquationData(pdesys, v) + eqs = pdesys.eqs + bcs = pdesys.bcs + alleqs = vcat(eqs, bcs) + + argmap = map(alleqs) do eq + eq => get_argument([eq], v)[1] + end + depvarmap = map(alleqs) do eq + eq => get_depvars(eq, v.depvar_ops) + end + indvarmap = map(alleqs) do eq + eq => get_indvars(eq, indvars(v)) + end + pde_indvars = if strategy isa QuadratureTraining + get_argument(eqs, v) + else + get_variables(eqs, v) + end + + bc_indvars = if strategy isa QuadratureTraining + get_argument(bcs, v) + else + get_variables(bcs, v) + end + + EquationData(depvarmap, indvarmap, pde_indvars, bc_depvars, argmap) +end + +function depvars(eq, eqdata::EquationData) + eqdata.depvarmap[eq] +end + +function indvars(eq, eqdata::EquationData) + eqdata.indvarmap[eq] +end + +function pde_indvars(eqdata::EquationData) + eqdata.pde_indvars +end + +function bc_indvars(eqdata::EquationData) + eqdata.bc_indvars +end + +argument(eq, eqdata) = eqdata.argmap[eq] diff --git a/src/pinn_types.jl b/src/pinn_types.jl index ebb69be08c..f73a53fb1e 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -2,43 +2,43 @@ ??? """ struct LogOptions - log_frequency::Int64 - # TODO: add in an option for saving plots in the log. this is currently not done because the type of plot is dependent on the PDESystem - # possible solution: pass in a plot function? - # this is somewhat important because we want to support plotting adaptive weights that depend on pde independent variables - # and not just one weight for each loss function, i.e. pde_loss_weights(i, t, x) and since this would be function-internal, - # we'd want the plot & log to happen internally as well - # plots of the learned function can happen in the outer callback, but we might want to offer that here too - - SciMLBase.@add_kwonly function LogOptions(; log_frequency = 50) - new(convert(Int64, log_frequency)) - end + log_frequency::Int64 + # TODO: add in an option for saving plots in the log. this is currently not done because the type of plot is dependent on the PDESystem + # possible solution: pass in a plot function? + # this is somewhat important because we want to support plotting adaptive weights that depend on pde independent variables + # and not just one weight for each loss function, i.e. pde_loss_weights(i, t, x) and since this would be function-internal, + # we'd want the plot & log to happen internally as well + # plots of the learned function can happen in the outer callback, but we might want to offer that here too + + SciMLBase.@add_kwonly function LogOptions(; log_frequency = 50) + new(convert(Int64, log_frequency)) + end end """This function is defined here as stubs to be overriden by the subpackage NeuralPDELogging if imported""" function logvector(logger, v::AbstractVector{R}, name::AbstractString, - step::Integer) where {R <: Real} - nothing + step::Integer) where {R <: Real} + nothing end """This function is defined here as stubs to be overriden by the subpackage NeuralPDELogging if imported""" function logscalar(logger, s::R, name::AbstractString, step::Integer) where {R <: Real} - nothing + nothing end """ ```julia PhysicsInformedNN(chain, - strategy; - init_params = nothing, - phi = nothing, - param_estim = false, - additional_loss = nothing, - adaptive_loss = nothing, - logger = nothing, - log_options = LogOptions(), - iteration = nothing, - kwargs...) 
where {iip} + strategy; + init_params = nothing, + phi = nothing, + param_estim = false, + additional_loss = nothing, + adaptive_loss = nothing, + logger = nothing, + log_options = LogOptions(), + iteration = nothing, + kwargs...) where {iip} ``` A `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a @@ -60,6 +60,7 @@ methodology. `init_params` should match `Flux.destructure(chain)[1]` in shape. If `init_params` is not given, then the neural network default parameters are used. Note that for Lux, the default will convert to Float64. +* `flat_init_params`: the initial parameters of the neural networks, flattened into a vector. * `phi`: a trial solution, specified as `phi(x,p)` where `x` is the coordinates vector for the dependent variable and `p` are the weights of the phi function (generally the weights of the neural network defining `phi`). By default, this is generated from the `chain`. This @@ -78,76 +79,155 @@ methodology. * `iteration`: used to control the iteration counter??? * `kwargs`: Extra keyword arguments which are splatted to the `OptimizationProblem` on `solve`. """ -struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN - chain::Any - strategy::T - init_params::P - phi::PH - derivative::DER - param_estim::PE - additional_loss::AL - adaptive_loss::ADA - logger::LOG - log_options::LogOptions - iteration::Vector{Int64} - self_increment::Bool - multioutput::Bool - kwargs::K - - @add_kwonly function PhysicsInformedNN(chain, - strategy; - init_params = nothing, - phi = nothing, - derivative = nothing, - param_estim = false, - additional_loss = nothing, - adaptive_loss = nothing, - logger = nothing, - log_options = LogOptions(), - iteration = nothing, - kwargs...) where {iip} - multioutput = typeof(chain) <: AbstractArray - - if phi === nothing - if multioutput - _phi = Phi.(chain) - else - _phi = Phi(chain) - end - else - _phi = phi - end - - if derivative === nothing - _derivative = numeric_derivative - else - _derivative = derivative - end - - if iteration isa Vector{Int64} - self_increment = false - else - iteration = [1] - self_increment = true - end - - new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), - typeof(param_estim), - typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain, - strategy, - init_params, - _phi, - _derivative, - param_estim, - additional_loss, - adaptive_loss, - logger, - log_options, - iteration, - self_increment, - multioutput, - kwargs) - end +struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K, F} <: SciMLBase.AbstractDiscretization + chain::Any + strategy::T + init_params::P + flat_init_params::F + phi::PH + derivative::DER + param_estim::PE + additional_loss::AL + adaptive_loss::ADA + logger::LOG + log_options::LogOptions + iteration::Vector{Int64} + self_increment::Bool + multioutput::Bool + kwargs::K + + @add_kwonly function PhysicsInformedNN(chain, + strategy; + init_params = nothing, + phi = nothing, + derivative = nothing, + param_estim = false, + additional_loss = nothing, + adaptive_loss = nothing, + logger = nothing, + log_options = LogOptions(), + iteration = nothing, + kwargs...) 
where {iip} + multioutput = typeof(chain) <: AbstractArray + + if phi === nothing + if multioutput + phi = Phi.(chain) + else + phi = Phi(chain) + end + end + + if derivative === nothing + _derivative = numeric_derivative + else + _derivative = derivative + end + + if iteration isa Vector{Int64} + self_increment = false + else + iteration = [1] + self_increment = true + end + + if init_params === nothing + # Use the initialization of the neural network framework + # But for Lux, default to Float64 + # For Flux, default to the types matching the values in the neural network + # This is done because Float64 is almost always better for these applications + # But with Flux there's already a chosen type from the user + + if chain isa AbstractArray + if chain[1] isa Flux.Chain + init_params = map(chain) do x + _x = Flux.destructure(x)[1] + end + else + x = map(chain) do x + _x = ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), + x)) + Float64.(_x) # No ComponentArray GPU support + end + names = ntuple(i -> depvars[i], length(chain)) + init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i + for i in x)) + end + else + if chain isa Flux.Chain + init_params = Flux.destructure(chain)[1] + init_params = init_params isa Array ? Float64.(init_params) : + init_params + else + init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), + chain))) + end + end + else + init_params = init_params + end + + if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f isa Optimisers.Restructure) + # Flux.Chain + flat_init_params = multioutput ? reduce(vcat, init_params) : init_params + flat_init_params = param_estim == false ? flat_init_params : + vcat(flat_init_params, + adapt(typeof(flat_init_params), default_p)) + else + flat_init_params = if init_params isa ComponentArrays.ComponentArray + init_params + elseif multioutput + @assert length(init_params) == length(depvars) + names = ntuple(i -> depvars[i], length(init_params)) + x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params)) + else + ComponentArrays.ComponentArray(init_params) + end + flat_init_params = if param_estim == false && multioutput + ComponentArrays.ComponentArray(; depvar = flat_init_params) + elseif param_estim == false && !multioutput + flat_init_params + else + ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p) + end + end + + if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer) + for ϕ in phi + ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), + ϕ.st) + end + elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer) + phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), + phi.st) + end + + eltypeθ = eltype(flat_init_params) + + if adaloss === nothing + adaloss = NonAdaptiveLoss{eltypeθ}() + end + + new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), + typeof(param_estim), + typeof(additional_loss), typeof(adaptive_loss), + typeof(logger), typeof(kwargs), typeof(flat_init_params)}(chain, + strategy, + init_params, + flat_init_params, + _phi, + _derivative, + param_estim, + additional_loss, + adaptive_loss, + logger, + log_options, + iteration, + self_increment, + multioutput, + kwargs) + end end """ @@ -161,137 +241,121 @@ used internally and returned for introspection by `symbolic_discretize`. 
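For example, a minimal sketch of how this struct is typically obtained and inspected (assuming a `pde_system` and a `discretization` have already been constructed):

```julia
pinnrep = symbolic_discretize(pde_system, discretization)
pinnrep.symbolic_pde_loss_functions  # generated PDE loss expressions
pinnrep.symbolic_bc_loss_functions   # generated boundary-condition loss expressions
```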
$(FIELDS) """ mutable struct PINNRepresentation - """ - The equations of the PDE - """ - eqs::Any - """ - The boundary condition equations - """ - bcs::Any - """ - The domains for each of the independent variables - """ - domains::Any - """ - ??? - """ - eq_params::Any - """ - ??? - """ - defaults::Any - """ - ??? - """ - default_p::Any - """ - Whether parameters are to be appended to the `additional_loss` - """ - param_estim::Any - """ - The `additional_loss` function as provided by the user - """ - additional_loss::Any - """ - The adaptive loss function - """ - adaloss::Any - """ - The dependent variables of the system - """ - depvars::Any - """ - The independent variables of the system - """ - indvars::Any - """ - A dictionary form of the independent variables. Define the structure ??? - """ - dict_indvars::Any - """ - A dictionary form of the dependent variables. Define the structure ??? - """ - dict_depvars::Any - """ - ??? - """ - dict_depvar_input::Any - """ - The logger as provided by the user - """ - logger::Any - """ - Whether there are multiple outputs, i.e. a system of PDEs - """ - multioutput::Bool - """ - The iteration counter used inside the cost function - """ - iteration::Vector{Int} - """ - The initial parameters as provided by the user. If the PDE is a system of PDEs, this - will be an array of arrays. If Lux.jl is used, then this is an array of ComponentArrays. - """ - init_params::Any - """ - The initial parameters as a flattened array. This is the array that is used in the - construction of the OptimizationProblem. If a Lux.jl neural network is used, then this - flattened form is a `ComponentArray`. If the equation is a system of equations, then - `flat_init_params.depvar.x` are the parameters for the neural network corresponding - to the dependent variable `x`, and i.e. if `depvar[i] == :x` then for `phi[i]`. - If `param_estim = true`, then `flat_init_params.p` are the parameters and - `flat_init_params.depvar.x` are the neural network parameters, so - `flat_init_params.depvar.x` would be the parameters of the neural network for the - dependent variable `x` if it's a system. If a Flux.jl neural network is used, this is - simply an `AbstractArray` to be indexed and the sizes from the chains must be - remembered/stored/used. - """ - flat_init_params::Any - """ - The representation of the test function of the PDE solution - """ - phi::Any - """ - The function used for computing the derivative - """ - derivative::Any - """ - The training strategy as provided by the user - """ - strategy::AbstractTrainingStrategy - """ - ??? - """ - pde_indvars::Any - """ - ??? - """ - bc_indvars::Any - """ - ??? - """ - pde_integration_vars::Any - """ - ??? - """ - bc_integration_vars::Any - """ - ??? - """ - integral::Any - """ - The PDE loss functions as represented in Julia AST - """ - symbolic_pde_loss_functions::Any - """ - The boundary condition loss functions as represented in Julia AST - """ - symbolic_bc_loss_functions::Any - """ - The PINNLossFunctions, i.e. the generated loss functions - """ - loss_functions::Any + """ + The equations of the PDE + """ + eqs::Any + """ + The boundary condition equations + """ + bcs::Any + """ + The domains for each of the independent variables + """ + domains::Any + """ + ??? + """ + eq_params::Any + """ + ??? + """ + defaults::Any + """ + ??? 
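(Presumably the default values of `eq_params`, taken from the `PDESystem` defaults; they are passed as `p` to the loss functions when `param_estim == false` and appended to the flattened parameter vector when `param_estim == true`.)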
+ """ + default_p::Any + """ + Whether parameters are to be appended to the `additional_loss` + """ + param_estim::Any + """ + The `additional_loss` function as provided by the user + """ + additional_loss::Any + """ + The adaptive loss function + """ + adaloss::Any + """ + The VariableMap of the PDESystem + """ + varmap::Any + """ + The logger as provided by the user + """ + logger::Any + """ + Whether there are multiple outputs, i.e. a system of PDEs + """ + multioutput::Bool + """ + The iteration counter used inside the cost function + """ + iteration::Vector{Int} + """ + The initial parameters as provided by the user. If the PDE is a system of PDEs, this + will be an array of arrays. If Lux.jl is used, then this is an array of ComponentArrays. + """ + init_params::Any + """ + The initial parameters as a flattened array. This is the array that is used in the + construction of the OptimizationProblem. If a Lux.jl neural network is used, then this + flattened form is a `ComponentArray`. If the equation is a system of equations, then + `flat_init_params.depvar.x` are the parameters for the neural network corresponding + to the dependent variable `x`, and i.e. if `depvar[i] == :x` then for `phi[i]`. + If `param_estim = true`, then `flat_init_params.p` are the parameters and + `flat_init_params.depvar.x` are the neural network parameters, so + `flat_init_params.depvar.x` would be the parameters of the neural network for the + dependent variable `x` if it's a system. If a Flux.jl neural network is used, this is + simply an `AbstractArray` to be indexed and the sizes from the chains must be + remembered/stored/used. + """ + flat_init_params::Any + """ + The representation of the test function of the PDE solution + """ + phi::Any + """ + The function used for computing the derivative + """ + derivative::Any + """ + The training strategy as provided by the user + """ + strategy::AbstractTrainingStrategy + """ + ??? + """ + pde_indvars::Any + """ + ??? + """ + bc_indvars::Any + """ + ??? + """ + pde_integration_vars::Any + """ + ??? + """ + bc_integration_vars::Any + """ + ??? + """ + integral::Any + """ + The PDE loss functions as represented in Julia AST + """ + symbolic_pde_loss_functions::Any + """ + The boundary condition loss functions as represented in Julia AST + """ + symbolic_bc_loss_functions::Any + """ + The PINNLossFunctions, i.e. the generated loss functions + """ + loss_functions::Any end """ @@ -304,31 +368,31 @@ The generated functions from the PINNRepresentation $(FIELDS) """ struct PINNLossFunctions - """ - The boundary condition loss functions - """ - bc_loss_functions::Any - """ - The PDE loss functions - """ - pde_loss_functions::Any - """ - The full loss function, combining the PDE and boundary condition loss functions. - This is the loss function that is used by the optimizer. - """ - full_loss_function::Any - """ - The wrapped `additional_loss`, as pieced together for the optimizer. - """ - additional_loss_function::Any - """ - The pre-data version of the PDE loss function - """ - datafree_pde_loss_functions::Any - """ - The pre-data version of the BC loss function - """ - datafree_bc_loss_functions::Any + """ + The boundary condition loss functions + """ + bc_loss_functions::Any + """ + The PDE loss functions + """ + pde_loss_functions::Any + """ + The full loss function, combining the PDE and boundary condition loss functions. + This is the loss function that is used by the optimizer. 
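In a typical workflow (a sketch, assuming the standard Optimization.jl interface) it is wrapped as `OptimizationFunction(full_loss_function, Optimization.AutoZygote())` and called by the optimizer as `full_loss_function(θ, p)`.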
+ """ + full_loss_function::Any + """ + The wrapped `additional_loss`, as pieced together for the optimizer. + """ + additional_loss_function::Any + """ + The pre-data version of the PDE loss function + """ + datafree_pde_loss_functions::Any + """ + The pre-data version of the BC loss function + """ + datafree_bc_loss_functions::Any end """ @@ -343,72 +407,72 @@ Fields: It should be updated on each call. """ mutable struct Phi{C, S} - f::C - st::S - function Phi(chain::Lux.AbstractExplicitLayer) - st = Lux.initialstates(Random.default_rng(), chain) - new{typeof(chain), typeof(st)}(chain, st) - end - function Phi(chain::Flux.Chain) - re = Flux.destructure(chain)[2] - new{typeof(re), Nothing}(re, nothing) - end + f::C + st::S + function Phi(chain::Lux.AbstractExplicitLayer) + st = Lux.initialstates(Random.default_rng(), chain) + new{typeof(chain), typeof(st)}(chain, st) + end + function Phi(chain::Flux.Chain) + re = Flux.destructure(chain)[2] + new{typeof(re), Nothing}(re, nothing) + end end function (f::Phi{<:Lux.AbstractExplicitLayer})(x::Number, θ) - y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), [x]), θ, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - y + y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), [x]), θ, f.st) + ChainRulesCore.@ignore_derivatives f.st = st + y end function (f::Phi{<:Lux.AbstractExplicitLayer})(x::AbstractArray, θ) - y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - y + y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st) + ChainRulesCore.@ignore_derivatives f.st = st + y end function (f::Phi{<:Optimisers.Restructure})(x, θ) - f.f(θ)(adapt(parameterless_type(θ), x)) + f.f(θ)(adapt(parameterless_type(θ), x)) end function get_u() - u = (cord, θ, phi) -> phi(cord, θ) + u = (cord, θ, phi) -> phi(cord, θ) end # the method to calculate the derivative function numeric_derivative(phi, u, x, εs, order, θ) - _type = parameterless_type(ComponentArrays.getdata(θ)) - - ε = εs[order] - _epsilon = inv(first(ε[ε .!= zero(ε)])) - - ε = adapt(_type, ε) - x = adapt(_type, x) - - # any(x->x!=εs[1],εs) - # εs is the epsilon for each order, if they are all the same then we use a fancy formula - # if order 1, this is trivially true - - if order > 4 || any(x -> x != εs[1], εs) - return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end - 1)]), order - 1, θ) - .- - numeric_derivative(phi, u, x .- ε, @view(εs[1:(end - 1)]), order - 1, θ)) .* - _epsilon ./ 2 - elseif order == 4 - return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi) - .+ - 6 .* u(x, θ, phi) - .- - 4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4 - elseif order == 3 - return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi) - - - u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2 - elseif order == 2 - return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2 - elseif order == 1 - return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2 - else - error("This shouldn't happen!") - end + _type = parameterless_type(ComponentArrays.getdata(θ)) + + ε = εs[order] + _epsilon = inv(first(ε[ε.!=zero(ε)])) + + ε = adapt(_type, ε) + x = adapt(_type, x) + + # any(x->x!=εs[1],εs) + # εs is the epsilon for each order, if they are all the same then we use a fancy formula + # if order 1, this is trivially true + + if order > 4 || any(x -> x != εs[1], εs) + return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end-1)]), order - 
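# Mixed and order > 4 derivatives are handled recursively: a first-order central
# difference (in the direction encoded by εs[order]) of the (order - 1)-th
# derivative. Each ε is assumed to be zero except for the step size in the slot
# of the variable being differentiated, so `x .+ ε` / `x .- ε` perturb only that
# coordinate and `_epsilon` is the reciprocal of that step.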
1, θ) + .- + numeric_derivative(phi, u, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* + _epsilon ./ 2 + elseif order == 4 + return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi) + .+ + 6 .* u(x, θ, phi) + .- + 4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4 + elseif order == 3 + return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi) + - + u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2 + elseif order == 2 + return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2 + elseif order == 1 + return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2 + else + error("This shouldn't happen! Got an order of $(order).") + end end diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 9161f3c365..9fcdf2378f 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -342,81 +342,17 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input) Dict(filter(p -> p !== nothing, pair_)) end -function get_vars(indvars_, depvars_) - indvars = ModelingToolkit.getname.(indvars_) - depvars = Symbol[] - dict_depvar_input = Dict{Symbol, Vector{Symbol}}() - for d in depvars_ - if unwrap(d) isa SymbolicUtils.BasicSymbolic - dname = ModelingToolkit.getname(d) - push!(depvars, dname) - push!(dict_depvar_input, - dname => [nameof(unwrap(argument)) - for argument in arguments(unwrap(d))]) - else - dname = ModelingToolkit.getname(d) - push!(depvars, dname) - push!(dict_depvar_input, dname => indvars) # default to all inputs if not given - end - end - - dict_indvars = get_dict_vars(indvars) - dict_depvars = get_dict_vars(depvars) - return depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input -end - -function get_integration_variables(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - get_integration_variables(eqs, dict_indvars, dict_depvars) -end - -function get_integration_variables(eqs, dict_indvars, dict_depvars) - exprs = toexpr.(eqs) - vars = map(exprs) do expr - _vars = Symbol.(filter(indvar -> length(find_thing_in_expr(expr, indvar)) > 0, - sort(collect(keys(dict_indvars))))) - end -end - -""" -``julia -get_variables(eqs,_indvars,_depvars) -``` - -Returns all variables that are used in each equations or boundary condition. 
-""" -function get_variables end - -function get_variables(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_variables(eqs, dict_indvars, dict_depvars) -end - -function get_variables(eqs, dict_indvars, dict_depvars) - bc_args = get_argument(eqs, dict_indvars, dict_depvars) - return map(barg -> filter(x -> x isa Symbol, barg), bc_args) -end +function pair(eq, v::VariableMap) -function get_number(eqs, dict_indvars, dict_depvars) - bc_args = get_argument(eqs, dict_indvars, dict_depvars) - return map(barg -> filter(x -> x isa Number, barg), bc_args) -end - -function find_thing_in_expr(ex::Expr, thing; ans = []) - if thing in ex.args - push!(ans, ex) - end - for e in ex.args - if e isa Expr - if thing in e.args - push!(ans, e) - end - find_thing_in_expr(e, thing; ans = ans) + pair_ = map(v.depvar_ops) do op + if !isempty(find_thing_in_expr(toexpr(eq), depvar)) + depvar => v.depvar_input[depvar] end end - return collect(Set(ans)) + +function get_integration_variables(eqs, v::VariableMap) + ivs = all_ivs(v) + return map(eq -> get_indvars(eq, ivs), eqs) end """ @@ -428,34 +364,43 @@ Returns all arguments that are used in each equations or boundary condition. """ function get_argument end -# Get arguments from boundary condition functions -function get_argument(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - get_argument(eqs, dict_indvars, dict_depvars) -end -function get_argument(eqs, dict_indvars, dict_depvars) - exprs = toexpr.(eqs) - vars = map(exprs) do expr - _vars = map(depvar -> find_thing_in_expr(expr, depvar), collect(keys(dict_depvars))) +function get_argument(eqs, v::VariableMap) + vars = map(eqs) do eq + _vars = map(depvar -> get_depvars(eq, depvar), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) map(x -> first(x), f_vars) end args_ = map(vars) do _vars - ind_args_ = map(var -> var.args[2:end], _vars) - syms = Set{Symbol}() - filter(vcat(ind_args_...)) do ind_arg - if ind_arg isa Symbol - if ind_arg ∈ syms + seen = [] + filter(reduce(vcat, arguments.(_vars))) do x + if x isa Number + true + else + if any(isequal(x), seen) false else - push!(syms, ind_arg) + push!(seen, x) true end - else - true end end end return args_ # TODO for all arguments end + +""" +``julia +get_variables(eqs,_indvars,_depvars) +``` + +Returns all variables that are used in each equations or boundary condition. 
+""" +function get_variables(eqs, v::VariableMap) + args = get_argument(eqs, v) + return map(arg -> filter(x -> !(x isa Number), arg), args) +end + +function get_number(eqs, v::VariableMap) + args = get_argument(eqs, v) + return map(arg -> filter(x -> x isa Number, arg), args) +end From 867ed377d55fd39ea67649c7fc2daf735506e92e Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 12 Apr 2023 17:47:29 +0100 Subject: [PATCH 02/40] include --- src/NeuralPDE.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 42f2088436..1054891b10 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -39,6 +39,7 @@ RuntimeGeneratedFunctions.init(@__MODULE__) abstract type AbstractTrainingStrategy end include("pinn_types.jl") +include("eq_data.jl") include("symbolic_utilities.jl") include("training_strategies.jl") include("adaptive_losses.jl") From 9d52d56aaf8dbe713a27332dfad61b79843928c0 Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 13 Apr 2023 17:20:53 +0100 Subject: [PATCH 03/40] further progress --- src/discretize.jl | 14 +++++++------- src/eq_data.jl | 2 +- src/symbolic_utilities.jl | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index 13bade9828..8f4d61b5da 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -40,15 +40,15 @@ for Flux.Chain, and for Lux.AbstractExplicitLayer """ -function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; +function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; eq_params = SciMLBase.NullParameters(), param_estim = false, default_p = nothing, - bc_indvars = pinnrep.indvars, + bc_indvars = pinnrep.v.x̄, integrand = nothing, dict_transformation_vars = nothing, transformation_vars = nothing, - integrating_depvars = pinnrep.depvars) + integrating_depvars = pinnrep.v.ū) @unpack v, eqdata, phi, derivative, integral, multioutput, init_params, strategy, eq_params, @@ -57,14 +57,14 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; eltypeθ = eltype(pinnrep.flat_init_params) if integrand isa Nothing - loss_function = parse_equation(pinnrep, eqs) - this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input) - this_eq_indvars = unique(vcat(values(this_eq_pair)...)) + loss_function = parse_equation(pinnrep, eq) + this_eq_pair = pair(eq, depvars, dict_depvars, dict_depvar_input) + this_eq_indvars = indvars(eq, eqmap) else this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars], integrating_depvars)) this_eq_indvars = transformation_vars isa Nothing ? 
- unique(vcat(values(this_eq_pair)...)) : transformation_vars + unique(indvars(eq, eqmap)) : transformation_vars loss_function = integrand end diff --git a/src/eq_data.jl b/src/eq_data.jl index 5475184d1e..241f584146 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -32,7 +32,7 @@ function EquationData(pdesys, v) get_variables(bcs, v) end - EquationData(depvarmap, indvarmap, pde_indvars, bc_depvars, argmap) + EquationData(depvarmap, indvarmap, pde_indvars, bc_indvars, argmap) end function depvars(eq, eqdata::EquationData) diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 9fcdf2378f..7a03cca4ad 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -342,11 +342,11 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input) Dict(filter(p -> p !== nothing, pair_)) end -function pair(eq, v::VariableMap) +function pair(eq, v::VariableMap, eqmap) pair_ = map(v.depvar_ops) do op - if !isempty(find_thing_in_expr(toexpr(eq), depvar)) - depvar => v.depvar_input[depvar] + if !isempty(depvars(eq, eqmap)) + depvar => v.args[depvar] end end From 5e27b180df111ea632ab5d855424d736d66fcaf3 Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 27 Apr 2023 18:14:23 +0100 Subject: [PATCH 04/40] start new loss --- src/new_loss.jl | 79 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 src/new_loss.jl diff --git a/src/new_loss.jl b/src/new_loss.jl new file mode 100644 index 0000000000..8a2d8862d8 --- /dev/null +++ b/src/new_loss.jl @@ -0,0 +1,79 @@ +function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; + eq_params = SciMLBase.NullParameters(), + param_estim = false, + default_p = nothing, + bc_indvars = pinnrep.v.x̄, + integrand = nothing, + dict_transformation_vars = nothing, + transformation_vars = nothing, + integrating_depvars = pinnrep.v.ū) + @unpack v, eqdata, + phi, derivative, integral, + multioutput, init_params, strategy, eq_params, + param_estim, default_p = pinnrep + + eltypeθ = eltype(pinnrep.flat_init_params) + + if integrand isa Nothing + loss_function = parse_equation(pinnrep, eq) + this_eq_indvars = indvars(eq, eqmap) + this_eq_depvars = depvars(eq, eqmap) + else + this_eq_indvars = transformation_vars isa Nothing ? + unique(indvars(eq, eqmap)) : transformation_vars + loss_function = integrand + end + + n = length(this_eq_indvars) + + full_loss_func = (cord, θ, phi, derivative, integral, u, p) -> begin + ivs = [cord[[i], :] for i in 1:n] + cords = map(this_eq_depvars) do w + idxs = map(x -> x2i(v, w, x), v.args[operation(w)])) + vcat(ivs[idxs]...) + end + loss_function(cords, θ, phi, derivative, integral, u, p) + end +end + +function operations(ex) + if istree(ex) + op = operation(ex) + return vcat(operations.(arguments(ex))..., op) + end + return [] +end + +function parse_equation(pinnrep::PINNRepresentation, ex; is_integral = false, + dict_transformation_vars = nothing, + transformation_vars = nothing) + @unpack v, eqdata, derivative, integral = pinnrep + + expr = scalarize(ex) + ex_vars = vars(expr) + ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) + ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) + op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) 
for op in ex_ops] + + dummyvars = @variables phi, u, x, θ + deriv_rules = generate_derivatives_rules(eq, eqdata, dummyvars) + + ch = Postwalk(Chain([deriv_rules; op_rules])) + expr = ch(expr) + + args = [phi, u, x, θ] + + ex = Func(args, [], eq.rhs) |> toexpr + + +end + +function generate_derivative_rules(eq, eqdata, dummyvars) + phi, u, coord, θ = dummyvars + @register_symbolic derivative(phi, u, coord, εs, order, θ) + rs = [[@rule $(Differential(x)^(~d)(w)) => derivative(phi, u, coord, get_εs(w), d, θ) + for x in all_ivs(w, v)] + for w in depvars(eq, eqdata)] + # TODO: add mixed derivatives + return reduce(vcat, rs) +end From 1947fb8c5dc290b813d2a701f7271b6aac05739a Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 16 May 2023 17:20:26 +0100 Subject: [PATCH 05/40] cardinalize eqs --- src/NeuralPDE.jl | 1 + src/discretize.jl | 1 + src/new_loss.jl | 6 +++--- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 1054891b10..c5151985ff 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -16,6 +16,7 @@ using QuasiMonteCarlo using RuntimeGeneratedFunctions using SciMLBase using PDEBase +using PDEBase: cardinalize_eqs! using Statistics using ArrayInterface import Optim diff --git a/src/discretize.jl b/src/discretize.jl index 8f4d61b5da..b471ffc929 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -400,6 +400,7 @@ For more information, see `discretize` and `PINNRepresentation`. """ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::PhysicsInformedNN) + cardinalize_eqs!(pde_system) eqs = pde_system.eqs bcs = pde_system.bcs chain = discretization.chain diff --git a/src/new_loss.jl b/src/new_loss.jl index 8a2d8862d8..97a93bfcea 100644 --- a/src/new_loss.jl +++ b/src/new_loss.jl @@ -56,16 +56,16 @@ function parse_equation(pinnrep::PINNRepresentation, ex; is_integral = false, op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] dummyvars = @variables phi, u, x, θ - deriv_rules = generate_derivatives_rules(eq, eqdata, dummyvars) + deriv_rules = generate_derivative_rules(eq, eqdata, dummyvars) ch = Postwalk(Chain([deriv_rules; op_rules])) expr = ch(expr) args = [phi, u, x, θ] - ex = Func(args, [], eq.rhs) |> toexpr - + ex = Func(args, [], ch(eq.lhs)) |> toexpr + return ex end function generate_derivative_rules(eq, eqdata, dummyvars) From 27fc8bce57c135ce28fc50025c9883202ae0d4b0 Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 17 May 2023 19:25:06 +0100 Subject: [PATCH 06/40] add todos --- src/new_loss.jl | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/new_loss.jl b/src/new_loss.jl index 97a93bfcea..33a7a93f18 100644 --- a/src/new_loss.jl +++ b/src/new_loss.jl @@ -1,3 +1,7 @@ +# TODO: add multioutput +# TODO: add param_estim +# TODO: add integrals + function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; eq_params = SciMLBase.NullParameters(), param_estim = false, @@ -15,9 +19,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; eltypeθ = eltype(pinnrep.flat_init_params) if integrand isa Nothing - loss_function = parse_equation(pinnrep, eq) this_eq_indvars = indvars(eq, eqmap) this_eq_depvars = depvars(eq, eqmap) + loss_function = parse_equation(pinnrep, eq, this_eq_indvars) else this_eq_indvars = transformation_vars isa Nothing ? 
unique(indvars(eq, eqmap)) : transformation_vars @@ -29,7 +33,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; full_loss_func = (cord, θ, phi, derivative, integral, u, p) -> begin ivs = [cord[[i], :] for i in 1:n] cords = map(this_eq_depvars) do w - idxs = map(x -> x2i(v, w, x), v.args[operation(w)])) + idxs = map(x -> x2i(v, w, x), v.args[operation(w)]) vcat(ivs[idxs]...) end loss_function(cords, θ, phi, derivative, integral, u, p) @@ -44,36 +48,37 @@ function operations(ex) return [] end -function parse_equation(pinnrep::PINNRepresentation, ex; is_integral = false, +function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = false, dict_transformation_vars = nothing, transformation_vars = nothing) @unpack v, eqdata, derivative, integral = pinnrep - expr = scalarize(ex) + expr = eq.lhs ex_vars = vars(expr) ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] - dummyvars = @variables phi, u, x, θ + dummyvars = @variables phi, u, θ deriv_rules = generate_derivative_rules(eq, eqdata, dummyvars) ch = Postwalk(Chain([deriv_rules; op_rules])) + + sym_coords = DestructuredArgs(ivs) + expr = ch(expr) - args = [phi, u, x, θ] + args = [phi, u, sym_coords, θ] - ex = Func(args, [], ch(eq.lhs)) |> toexpr + ex = Func(args, [], expr) |> toexpr return ex end function generate_derivative_rules(eq, eqdata, dummyvars) - phi, u, coord, θ = dummyvars + phi, u, θ = dummyvars @register_symbolic derivative(phi, u, coord, εs, order, θ) - rs = [[@rule $(Differential(x)^(~d)(w)) => derivative(phi, u, coord, get_εs(w), d, θ) - for x in all_ivs(w, v)] - for w in depvars(eq, eqdata)] + rs = [@rule $(Differential(~x)^(~d)(~w)) => derivative(phi, u, ~x, get_εs(~w), ~d, θ)] # TODO: add mixed derivatives - return reduce(vcat, rs) + return rs end From 1fcd390198938141bb84a0cdd0c4e7b38e5442da Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 25 May 2023 18:57:33 +0100 Subject: [PATCH 07/40] polish loss and refactor --- src/NeuralPDE.jl | 1 + src/discretize.jl | 12 +-- src/eq_data.jl | 84 +++++++++++++++----- src/loss_function_generation.jl | 131 ++++++++++++++++++++++++++++++++ src/new_loss.jl | 84 -------------------- src/symbolic_utilities.jl | 8 -- 6 files changed, 199 insertions(+), 121 deletions(-) create mode 100644 src/loss_function_generation.jl delete mode 100644 src/new_loss.jl diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index c5151985ff..f1d25464e7 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -47,6 +47,7 @@ include("adaptive_losses.jl") include("ode_solve.jl") include("rode_solve.jl") include("transform_inf_integral.jl") +include("loss_function_generation.jl") include("discretize.jl") include("neural_adapter.jl") diff --git a/src/discretize.jl b/src/discretize.jl index b471ffc929..3a638b1997 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -444,15 +444,9 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, integral = get_numeric_integral(pinnrep) - symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq; - bc_indvars = pde_indvar) - for (eq, pde_indvar) in zip(eqs, pde_indvars, - pde_integration_vars)] - - symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc; - bc_indvars = bc_indvar) - for (bc, bc_indvar) in zip(bcs, bc_indvars, - bc_integration_vars)] + symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq) for eq in eqs] + + 
symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc) for bc in bcs] pinnrep.integral = integral pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions diff --git a/src/eq_data.jl b/src/eq_data.jl index 241f584146..f471353671 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -1,8 +1,8 @@ struct EquationData <: PDEBase.AbstractVarEqMapping depvarmap indvarmap - pde_indvars - bc_indvars + args + ivargs argmap end @@ -13,26 +13,31 @@ function EquationData(pdesys, v) argmap = map(alleqs) do eq eq => get_argument([eq], v)[1] - end + end |> Dict depvarmap = map(alleqs) do eq eq => get_depvars(eq, v.depvar_ops) - end + end |> Dict indvarmap = map(alleqs) do eq eq => get_indvars(eq, indvars(v)) - end - pde_indvars = if strategy isa QuadratureTraining - get_argument(eqs, v) - else - get_variables(eqs, v) - end + end |> Dict - bc_indvars = if strategy isa QuadratureTraining - get_argument(bcs, v) - else - get_variables(bcs, v) - end + args = map(alleqs) do eq + if strategy isa QuadratureTraining + eq => get_argument(bcs, v) + else + eq => get_variables(bcs, v) + end + end |> Dict + + ivargs = map(alleqs) do eq + if strategy isa QuadratureTraining + eq => get_iv_argument(eqs, v) + else + eq => get_iv_variables(eqs, v) + end + end |> Dict - EquationData(depvarmap, indvarmap, pde_indvars, bc_indvars, argmap) + EquationData(depvarmap, indvarmap, args, ivargs, argmap) end function depvars(eq, eqdata::EquationData) @@ -43,12 +48,51 @@ function indvars(eq, eqdata::EquationData) eqdata.indvarmap[eq] end -function pde_indvars(eqdata::EquationData) - eqdata.pde_indvars +function eq_args(eq, eqdata::EquationData) + eqdata.args[eq] end -function bc_indvars(eqdata::EquationData) - eqdata.bc_indvars +function eq_iv_args(eq, eqdata::EquationData) + eqdata.ivargs[eq] end argument(eq, eqdata) = eqdata.argmap[eq] + + +function get_iv_argument(eqs, v::VariableMap) + vars = map(eqs) do eq + _vars = map(depvar -> get_depvars(eq, depvar), v.depvar_ops) + f_vars = filter(x -> !isempty(x), _vars) + v.args[operation(map(x -> first(x), f_vars))] + end + args_ = map(vars) do _vars + seen = [] + args_ = map(vars) do _vars + seen = [] + filter(reduce(vcat, arguments.(_vars))) do x + if x isa Number + true + else + if any(isequal(x), seen) + false + else + push!(seen, x) + true + end + end + end + end + return args_ # TODO for all arguments +end + +""" +``julia +get_variables(eqs,_indvars,_depvars) +``` + +Returns all variables that are used in each equations or boundary condition. 
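For example (a sketch): for a boundary condition `u(0, x) ~ 0`, `get_iv_argument` keeps the literal coordinate and returns `[[0, x]]`, whereas this function drops numbers and returns `[[x]]`.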
+""" +function get_iv_variables(eqs, v::VariableMap) + args = get_iv_argument(eqs, v) + return map(arg -> filter(x -> !(x isa Number), arg), args) +end diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl new file mode 100644 index 0000000000..97beeee8c4 --- /dev/null +++ b/src/loss_function_generation.jl @@ -0,0 +1,131 @@ +# TODO: add multioutput +# TODO: add integrals + +function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; + eq_params = SciMLBase.NullParameters(), + param_estim = false, + default_p = [], + integrand = nothing, + transformation_vars = nothing) + @unpack v, eqdata, + phi, derivative, integral, + multioutput, init_params, strategy, eq_params, + param_estim, default_p = pinnrep + + eltypeθ = eltype(pinnrep.flat_init_params) + + eq_args = get(eqdata.eq_args, eq, v.x̄) + + if integrand isa Nothing + this_eq_indvars = indvars(eq, eqmap) + this_eq_depvars = depvars(eq, eqmap) + loss_function = parse_equation(pinnrep, eq, eq_iv_args(eq, eqmap)) + else + this_eq_indvars = transformation_vars isa Nothing ? + unique(indvars(eq, eqmap)) : transformation_vars + loss_function = integrand + end + + n = length(this_eq_indvars) + + if param_estim == true && eq_params != SciMLBase.NullParameters() + param_len = length(eq_params) + # check parameter format to use correct indexing + psform = (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f isa Optimisers.Restructure) + + if psform + last_indx = [0; accumulate(+, map(length, init_params))][end] + ps_range = 1:param_len .+ last_indx + get_ps = (θ) -> θ[ps_range] + else + ps_range = 1:param_len + get_ps = (θ) -> θ.p[ps_range] + end + else + get_ps = (θ) -> default_p + end + + function get_coords(cord) + map(enumerate(eq_args)) do (i, x) + if x isa Number + fill(x, size(cord[[1], :])) + else + cord[[i], :] + end + end + end + + full_loss_func = (cord, θ, phi, derivative, integral, u, p) -> begin + loss_function(get_coords(cord), θ, phi, derivative, integral, u, get_ps(θ)) + end + return full_loss_func +end + +function build_loss_function(pinnrep, eqs) + @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep + + _loss_function = build_symbolic_loss_function(pinnrep, eqs, + eq_params = eq_params, + param_estim = param_estim) + + u = get_u() + loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, derivative, integral, u, + default_p) end + return loss_function +end + +function operations(ex) + if istree(ex) + op = operation(ex) + return vcat(operations.(arguments(ex))..., op) + end + return [] +end + +############################################################################################ +# Parse equation +############################################################################################ + +function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = false, + dict_transformation_vars = nothing, + transformation_vars = nothing) + @unpack v, eqdata, derivative, integral = pinnrep + + expr = eq isa Equation ? eq.lhs : eq + ex_vars = vars(expr) + ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) + ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) + op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) 
for op in ex_ops] + + dummyvars = @variables phi, u, θ + deriv_rules = generate_derivative_rules(eq, eqdata, dummyvars) + + ch = Postwalk(Chain([deriv_rules; op_rules])) + expr = ch(expr) + + sym_coords = DestructuredArgs(ivs) + ps = DestructuredArgs(v.ps) + + + args = [sym_coords, θ, phi, u, ps] + + ex = Func(args, [], expr) |> toexpr + + return ex +end + +function generate_derivative_rules(eq, eqdata, dummyvars) + phi, u, θ = dummyvars + @register_symbolic derivative(phi, u, coord, εs, order, θ) + rs = [@rule $(Differential(~x)^(~d::isinteger)(~w)) => derivative(phi, u, ~x, get_εs(~w), ~d, θ)] + # TODO: add mixed derivatives + return rs +end + +function generate_integral_rules(eq, eqdata, dummyvars) + phi, u, θ = dummyvars + #! all that should be needed is to solve an integral problem, the trick is doing this + #! with rules without putting symbols through the solve + +end diff --git a/src/new_loss.jl b/src/new_loss.jl deleted file mode 100644 index 33a7a93f18..0000000000 --- a/src/new_loss.jl +++ /dev/null @@ -1,84 +0,0 @@ -# TODO: add multioutput -# TODO: add param_estim -# TODO: add integrals - -function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; - eq_params = SciMLBase.NullParameters(), - param_estim = false, - default_p = nothing, - bc_indvars = pinnrep.v.x̄, - integrand = nothing, - dict_transformation_vars = nothing, - transformation_vars = nothing, - integrating_depvars = pinnrep.v.ū) - @unpack v, eqdata, - phi, derivative, integral, - multioutput, init_params, strategy, eq_params, - param_estim, default_p = pinnrep - - eltypeθ = eltype(pinnrep.flat_init_params) - - if integrand isa Nothing - this_eq_indvars = indvars(eq, eqmap) - this_eq_depvars = depvars(eq, eqmap) - loss_function = parse_equation(pinnrep, eq, this_eq_indvars) - else - this_eq_indvars = transformation_vars isa Nothing ? - unique(indvars(eq, eqmap)) : transformation_vars - loss_function = integrand - end - - n = length(this_eq_indvars) - - full_loss_func = (cord, θ, phi, derivative, integral, u, p) -> begin - ivs = [cord[[i], :] for i in 1:n] - cords = map(this_eq_depvars) do w - idxs = map(x -> x2i(v, w, x), v.args[operation(w)]) - vcat(ivs[idxs]...) - end - loss_function(cords, θ, phi, derivative, integral, u, p) - end -end - -function operations(ex) - if istree(ex) - op = operation(ex) - return vcat(operations.(arguments(ex))..., op) - end - return [] -end - -function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = false, - dict_transformation_vars = nothing, - transformation_vars = nothing) - @unpack v, eqdata, derivative, integral = pinnrep - - expr = eq.lhs - ex_vars = vars(expr) - ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) - ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) - op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) 
for op in ex_ops] - - dummyvars = @variables phi, u, θ - deriv_rules = generate_derivative_rules(eq, eqdata, dummyvars) - - ch = Postwalk(Chain([deriv_rules; op_rules])) - - sym_coords = DestructuredArgs(ivs) - - expr = ch(expr) - - args = [phi, u, sym_coords, θ] - - ex = Func(args, [], expr) |> toexpr - - return ex -end - -function generate_derivative_rules(eq, eqdata, dummyvars) - phi, u, θ = dummyvars - @register_symbolic derivative(phi, u, coord, εs, order, θ) - rs = [@rule $(Differential(~x)^(~d)(~w)) => derivative(phi, u, ~x, get_εs(~w), ~d, θ)] - # TODO: add mixed derivatives - return rs -end diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 7a03cca4ad..7c423657db 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -342,14 +342,6 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input) Dict(filter(p -> p !== nothing, pair_)) end -function pair(eq, v::VariableMap, eqmap) - - pair_ = map(v.depvar_ops) do op - if !isempty(depvars(eq, eqmap)) - depvar => v.args[depvar] - end - end - function get_integration_variables(eqs, v::VariableMap) ivs = all_ivs(v) return map(eq -> get_indvars(eq, ivs), eqs) From b85eaa83503d258a2a504fdb9dca6e485630a06a Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 25 May 2023 18:58:13 +0100 Subject: [PATCH 08/40] fix --- src/eq_data.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/eq_data.jl b/src/eq_data.jl index f471353671..65373ccc61 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -18,7 +18,7 @@ function EquationData(pdesys, v) eq => get_depvars(eq, v.depvar_ops) end |> Dict indvarmap = map(alleqs) do eq - eq => get_indvars(eq, indvars(v)) + eq => get_indvars(eq, v) end |> Dict args = map(alleqs) do eq From 8e39ede9b199763596c73644b6cd6aeee64187b8 Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 26 May 2023 15:46:53 +0100 Subject: [PATCH 09/40] "final" parsing updates --- src/discretize.jl | 341 +++++++++++++++++-------------------- src/eq_data.jl | 4 +- src/training_strategies.jl | 8 +- 3 files changed, 163 insertions(+), 190 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index 3a638b1997..e26b45b016 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -40,159 +40,159 @@ for Flux.Chain, and for Lux.AbstractExplicitLayer """ -function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; - eq_params = SciMLBase.NullParameters(), - param_estim = false, - default_p = nothing, - bc_indvars = pinnrep.v.x̄, - integrand = nothing, - dict_transformation_vars = nothing, - transformation_vars = nothing, - integrating_depvars = pinnrep.v.ū) - @unpack v, eqdata, - phi, derivative, integral, - multioutput, init_params, strategy, eq_params, - param_estim, default_p = pinnrep - - eltypeθ = eltype(pinnrep.flat_init_params) - - if integrand isa Nothing - loss_function = parse_equation(pinnrep, eq) - this_eq_pair = pair(eq, depvars, dict_depvars, dict_depvar_input) - this_eq_indvars = indvars(eq, eqmap) - else - this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars], - integrating_depvars)) - this_eq_indvars = transformation_vars isa Nothing ? 
- unique(indvars(eq, eqmap)) : transformation_vars - loss_function = integrand - end - - vars = :(cord, $θ, phi, derivative, integral, u, p) - ex = Expr(:block) - if multioutput - θ_nums = Symbol[] - phi_nums = Symbol[] - for v in depvars - num = dict_depvars[v] - push!(θ_nums, :($(Symbol(:($θ), num)))) - push!(phi_nums, :($(Symbol(:phi, num)))) - end - - expr_θ = Expr[] - expr_phi = Expr[] - - acum = [0; accumulate(+, map(length, init_params))] - sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] - - for i in eachindex(depvars) - if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || - (!(phi isa Vector) && phi.f isa Optimisers.Restructure) - # Flux.Chain - push!(expr_θ, :($θ[$(sep[i])])) - else # Lux.AbstractExplicitLayer - push!(expr_θ, :($θ.depvar.$(depvars[i]))) - end - push!(expr_phi, :(phi[$i])) - end - - vars_θ = Expr(:(=), build_expr(:tuple, θ_nums), build_expr(:tuple, expr_θ)) - push!(ex.args, vars_θ) - - vars_phi = Expr(:(=), build_expr(:tuple, phi_nums), build_expr(:tuple, expr_phi)) - push!(ex.args, vars_phi) - end - - #Add an expression for parameter symbols - if param_estim == true && eq_params != SciMLBase.NullParameters() - param_len = length(eq_params) - last_indx = [0; accumulate(+, map(length, init_params))][end] - params_symbols = Symbol[] - expr_params = Expr[] - for (i, eq_param) in enumerate(eq_params) - if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || - (!(phi isa Vector) && phi.f isa Optimisers.Restructure) - push!(expr_params, :($θ[$((i + last_indx):(i + last_indx))])) - else - push!(expr_params, :($θ.p[$((i):(i))])) - end - push!(params_symbols, Symbol(:($eq_param))) - end - params_eq = Expr(:(=), build_expr(:tuple, params_symbols), - build_expr(:tuple, expr_params)) - push!(ex.args, params_eq) - end - - if eq_params != SciMLBase.NullParameters() && param_estim == false - params_symbols = Symbol[] - expr_params = Expr[] - for (i, eq_param) in enumerate(eq_params) - push!(expr_params, :(ArrayInterface.allowed_getindex(p, ($i):($i)))) - push!(params_symbols, Symbol(:($eq_param))) - end - params_eq = Expr(:(=), build_expr(:tuple, params_symbols), - build_expr(:tuple, expr_params)) - push!(ex.args, params_eq) - end - - eq_pair_expr = Expr[] - for i in keys(this_eq_pair) - push!(eq_pair_expr, :($(Symbol(:cord, :($i))) = vcat($(this_eq_pair[i]...)))) - end - vcat_expr = Expr(:block, :($(eq_pair_expr...))) - vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename - - if strategy isa QuadratureTraining - indvars_ex = get_indvars_ex(bc_indvars) - left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex - vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), - build_expr(:tuple, right_arg_pairs)) - else - indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] - left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex - vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), - build_expr(:tuple, right_arg_pairs)) - end - - if !(dict_transformation_vars isa Nothing) - transformation_expr_ = Expr[] - - for (i, u) in dict_transformation_vars - push!(transformation_expr_, :($i = $u)) - end - transformation_expr = Expr(:block, :($(transformation_expr_...))) - vcat_expr_loss_functions = Expr(:block, transformation_expr, vcat_expr, - loss_function) - end - let_ex = Expr(:let, vars_eq, vcat_expr_loss_functions) - push!(ex.args, let_ex) - expr_loss_function = :(($vars) -> begin $ex end) -end - -""" -```julia -build_loss_function(eqs, indvars, depvars, phi, derivative, init_params; 
bc_indvars=nothing) -``` - -Returns the body of loss function, which is the executable Julia function, for the main -equation or boundary condition. -""" -function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars) - @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep - - bc_indvars = bc_indvars === nothing ? pinnrep.indvars : bc_indvars - - expr_loss_function = build_symbolic_loss_function(pinnrep, eqs; - bc_indvars = bc_indvars, - eq_params = eq_params, - param_estim = param_estim, - default_p = default_p) - u = get_u() - _loss_function = @RuntimeGeneratedFunction(expr_loss_function) - loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, derivative, integral, u, - default_p) end - return loss_function -end +# function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; +# eq_params = SciMLBase.NullParameters(), +# param_estim = false, +# default_p = nothing, +# bc_indvars = pinnrep.v.x̄, +# integrand = nothing, +# dict_transformation_vars = nothing, +# transformation_vars = nothing, +# integrating_depvars = pinnrep.v.ū) +# @unpack v, eqdata, +# phi, derivative, integral, +# multioutput, init_params, strategy, eq_params, +# param_estim, default_p = pinnrep + +# eltypeθ = eltype(pinnrep.flat_init_params) + +# if integrand isa Nothing +# loss_function = parse_equation(pinnrep, eq) +# this_eq_pair = pair(eq, depvars, dict_depvars, dict_depvar_input) +# this_eq_indvars = indvars(eq, eqmap) +# else +# this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars], +# integrating_depvars)) +# this_eq_indvars = transformation_vars isa Nothing ? +# unique(indvars(eq, eqmap)) : transformation_vars +# loss_function = integrand +# end + +# vars = :(cord, $θ, phi, derivative, integral, u, p) +# ex = Expr(:block) +# if multioutput +# θ_nums = Symbol[] +# phi_nums = Symbol[] +# for v in depvars +# num = dict_depvars[v] +# push!(θ_nums, :($(Symbol(:($θ), num)))) +# push!(phi_nums, :($(Symbol(:phi, num)))) +# end + +# expr_θ = Expr[] +# expr_phi = Expr[] + +# acum = [0; accumulate(+, map(length, init_params))] +# sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] + +# for i in eachindex(depvars) +# if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || +# (!(phi isa Vector) && phi.f isa Optimisers.Restructure) +# # Flux.Chain +# push!(expr_θ, :($θ[$(sep[i])])) +# else # Lux.AbstractExplicitLayer +# push!(expr_θ, :($θ.depvar.$(depvars[i]))) +# end +# push!(expr_phi, :(phi[$i])) +# end + +# vars_θ = Expr(:(=), build_expr(:tuple, θ_nums), build_expr(:tuple, expr_θ)) +# push!(ex.args, vars_θ) + +# vars_phi = Expr(:(=), build_expr(:tuple, phi_nums), build_expr(:tuple, expr_phi)) +# push!(ex.args, vars_phi) +# end + +# #Add an expression for parameter symbols +# if param_estim == true && eq_params != SciMLBase.NullParameters() +# param_len = length(eq_params) +# last_indx = [0; accumulate(+, map(length, init_params))][end] +# params_symbols = Symbol[] +# expr_params = Expr[] +# for (i, eq_param) in enumerate(eq_params) +# if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || +# (!(phi isa Vector) && phi.f isa Optimisers.Restructure) +# push!(expr_params, :($θ[$((i + last_indx):(i + last_indx))])) +# else +# push!(expr_params, :($θ.p[$((i):(i))])) +# end +# push!(params_symbols, Symbol(:($eq_param))) +# end +# params_eq = Expr(:(=), build_expr(:tuple, params_symbols), +# build_expr(:tuple, expr_params)) +# push!(ex.args, params_eq) +# end + +# if eq_params != SciMLBase.NullParameters() && 
param_estim == false +# params_symbols = Symbol[] +# expr_params = Expr[] +# for (i, eq_param) in enumerate(eq_params) +# push!(expr_params, :(ArrayInterface.allowed_getindex(p, ($i):($i)))) +# push!(params_symbols, Symbol(:($eq_param))) +# end +# params_eq = Expr(:(=), build_expr(:tuple, params_symbols), +# build_expr(:tuple, expr_params)) +# push!(ex.args, params_eq) +# end + +# eq_pair_expr = Expr[] +# for i in keys(this_eq_pair) +# push!(eq_pair_expr, :($(Symbol(:cord, :($i))) = vcat($(this_eq_pair[i]...)))) +# end +# vcat_expr = Expr(:block, :($(eq_pair_expr...))) +# vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename + +# if strategy isa QuadratureTraining +# indvars_ex = get_indvars_ex(bc_indvars) +# left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex +# vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), +# build_expr(:tuple, right_arg_pairs)) +# else +# indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] +# left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex +# vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), +# build_expr(:tuple, right_arg_pairs)) +# end + +# if !(dict_transformation_vars isa Nothing) +# transformation_expr_ = Expr[] + +# for (i, u) in dict_transformation_vars +# push!(transformation_expr_, :($i = $u)) +# end +# transformation_expr = Expr(:block, :($(transformation_expr_...))) +# vcat_expr_loss_functions = Expr(:block, transformation_expr, vcat_expr, +# loss_function) +# end +# let_ex = Expr(:let, vars_eq, vcat_expr_loss_functions) +# push!(ex.args, let_ex) +# expr_loss_function = :(($vars) -> begin $ex end) +# end + +# """ +# ```julia +# build_loss_function(eqs, indvars, depvars, phi, derivative, init_params; bc_indvars=nothing) +# ``` + +# Returns the body of loss function, which is the executable Julia function, for the main +# equation or boundary condition. +# """ +# function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars) +# @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep + +# bc_indvars = bc_indvars === nothing ? pinnrep.indvars : bc_indvars + +# expr_loss_function = build_symbolic_loss_function(pinnrep, eqs; +# bc_indvars = bc_indvars, +# eq_params = eq_params, +# param_estim = param_estim, +# default_p = default_p) +# u = get_u() +# _loss_function = @RuntimeGeneratedFunction(expr_loss_function) +# loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, derivative, integral, u, +# default_p) end +# return loss_function +# end """ ```julia @@ -290,7 +290,6 @@ end function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining) dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) - #! 
Fix this to work with a var_eq mapping pde_args = get_argument(eqs, v) pde_lower_bounds = map(pde_args) do pd @@ -303,7 +302,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Quadr end pde_bounds = [pde_lower_bounds, pde_upper_bounds] - bound_vars = get_variables(bcs, dict_indvars, dict_depvars) + bound_vars = get_variables(bcs, v) bcs_lower_bounds = map(bound_vars) do bt map(b -> dict_lower_bound[b], bt) @@ -316,30 +315,6 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Quadr [pde_bounds, bcs_bounds] end -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) - dx = 1 / strategy.points - dict_span = Dict([Symbol(d.variables) => [ - infimum(d.domain) + dx, - supremum(d.domain) - dx, - ] for d in domains]) - - # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] - pde_args = get_argument(eqs, dict_indvars, dict_depvars) - pde_bounds = map(pde_args) do pde_arg - bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) - bds = eltypeθ.(bds) - bds[1, :], bds[2, :] - end - - bound_args = get_argument(bcs, dict_indvars, dict_depvars) - bcs_bounds = map(bound_args) do bound_arg - bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg) - bds = eltypeθ.(bds) - bds[1, :], bds[2, :] - end - return pde_bounds, bcs_bounds -end - function get_numeric_integral(pinnrep::PINNRepresentation) @unpack strategy, multioutput, derivative, varmap = pinnrep @@ -442,13 +417,13 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, derivative, strategy, eqdata, nothing, nothing, nothing, nothing) - integral = get_numeric_integral(pinnrep) + #integral = get_numeric_integral(pinnrep) - symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq) for eq in eqs] + #symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq) for eq in eqs] - symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc) for bc in bcs] + #symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc) for bc in bcs] - pinnrep.integral = integral + #pinnrep.integral = integral pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions diff --git a/src/eq_data.jl b/src/eq_data.jl index 65373ccc61..6dfb65df7d 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -66,8 +66,6 @@ function get_iv_argument(eqs, v::VariableMap) v.args[operation(map(x -> first(x), f_vars))] end args_ = map(vars) do _vars - seen = [] - args_ = map(vars) do _vars seen = [] filter(reduce(vcat, arguments.(_vars))) do x if x isa Number @@ -82,7 +80,7 @@ function get_iv_argument(eqs, v::VariableMap) end end end - return args_ # TODO for all arguments + return args_ end """ diff --git a/src/training_strategies.jl b/src/training_strategies.jl index 6c6dacbb7a..d87bc64d7c 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -80,11 +80,11 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::StochasticTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, v, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, v, strategy) pde_bounds, bcs_bounds = bounds @@ -304,8 +304,8 @@ end WeightedIntervalTraining(weights, 
samples) ``` -A training strategy that generates points for training based on the given inputs. -We split the timespan into equal segments based on the number of weights, +A training strategy that generates points for training based on the given inputs. +We split the timespan into equal segments based on the number of weights, then sample points in each segment based on that segments corresponding weight, such that the total number of sampled points is equivalent to the given samples From 0f998113dd9f8d404e5d88a512385cade97af72a Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 26 May 2023 15:47:00 +0100 Subject: [PATCH 10/40] oops --- src/discretize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/discretize.jl b/src/discretize.jl index e26b45b016..15670e9e22 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -314,7 +314,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Quadr [pde_bounds, bcs_bounds] end - +# TODO: Get this to work with varmap function get_numeric_integral(pinnrep::PINNRepresentation) @unpack strategy, multioutput, derivative, varmap = pinnrep From 3806caf500a4a2e3b24335e5a1906b12f5e232e1 Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 30 May 2023 17:05:18 +0100 Subject: [PATCH 11/40] test fixes --- src/NeuralPDE.jl | 4 +- src/discretize.jl | 116 +++++++++++++++++---- src/eq_data.jl | 8 +- src/loss_function_generation.jl | 7 +- src/pinn_types.jl | 179 ++++++++++---------------------- src/symbolic_utilities.jl | 6 +- test/NNPDE_tests.jl | 2 +- 7 files changed, 162 insertions(+), 160 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index f1d25464e7..7a87035a65 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -16,7 +16,7 @@ using QuasiMonteCarlo using RuntimeGeneratedFunctions using SciMLBase using PDEBase -using PDEBase: cardinalize_eqs! +using PDEBase: cardinalize_eqs!, get_depvars, get_indvars using Statistics using ArrayInterface import Optim @@ -37,6 +37,8 @@ import ChainRulesCore: @non_differentiable RuntimeGeneratedFunctions.init(@__MODULE__) +abstract type AbstractPINN <: SciMLBase.AbstractDiscretization end + abstract type AbstractTrainingStrategy end include("pinn_types.jl") diff --git a/src/discretize.jl b/src/discretize.jl index 15670e9e22..20d9437035 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -362,7 +362,7 @@ end """ ```julia -prob = symbolic_discretize(pde_system::PDESystem, discretization::PhysicsInformedNN) +prob = symbolic_discretize(pdesys::PDESystem, discretization::PhysicsInformedNN) ``` `symbolic_discretize` is the lower level interface to `discretize` for inspecting internals. @@ -373,16 +373,16 @@ to the PDE. For more information, see `discretize` and `PINNRepresentation`. """ -function SciMLBase.symbolic_discretize(pde_system::PDESystem, +function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::PhysicsInformedNN) - cardinalize_eqs!(pde_system) - eqs = pde_system.eqs - bcs = pde_system.bcs + cardinalize_eqs!(pdesys) + eqs = pdesys.eqs + bcs = pdesys.bcs chain = discretization.chain - domains = pde_system.domain - eq_params = pde_system.ps - defaults = pde_system.defaults + domains = pdesys.domain + eq_params = pdesys.ps + defaults = pdesys.defaults default_p = eq_params == SciMLBase.NullParameters() ? 
nothing : [defaults[ep] for ep in eq_params] @@ -390,14 +390,6 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, additional_loss = discretization.additional_loss adaloss = discretization.adaptive_loss - v = VariableMap(pde_system, discretization) - - # Find the derivative orders in the bcs - bcorders = Dict(map(x -> x => d_orders(x, pdesys.bcs), all_ivs(v))) - # Create a map of each variable to their boundary conditions including initial conditions - boundarymap = parse_bcs(pdesys.bcs, v, bcorders) - - eqdata = EquationData(pdesys, v) multioutput = discretization.multioutput init_params = discretization.init_params @@ -411,8 +403,92 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, iteration = discretization.iteration self_increment = discretization.self_increment + v = VariableMap(pdesys, discretization) + + eqdata = EquationData(pdesys, v, strategy) + + + if init_params === nothing + # Use the initialization of the neural network framework + # But for Lux, default to Float64 + # For Flux, default to the types matching the values in the neural network + # This is done because Float64 is almost always better for these applications + # But with Flux there's already a chosen type from the user + + if chain isa AbstractArray + if chain[1] isa Flux.Chain + init_params = map(chain) do x + _x = Flux.destructure(x)[1] + end + else + x = map(chain) do x + _x = ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), + x)) + Float64.(_x) # No ComponentArray GPU support + end + names = ntuple(i -> ~Symbol.(v.ū)[i], length(chain)) + init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i + for i in x)) + end + else + if chain isa Flux.Chain + init_params = Flux.destructure(chain)[1] + init_params = init_params isa Array ? Float64.(init_params) : + init_params + else + init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), + chain))) + end + end + else + init_params = init_params + end + + if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f isa Optimisers.Restructure) + # Flux.Chain + flat_init_params = multioutput ? reduce(vcat, init_params) : init_params + flat_init_params = param_estim == false ? 
flat_init_params : + vcat(flat_init_params, + adapt(typeof(flat_init_params), default_p)) + else + flat_init_params = if init_params isa ComponentArrays.ComponentArray + init_params + elseif multioutput + @assert length(init_params) == length(depvars) + names = ntuple(i -> depvars[i], length(init_params)) + x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params)) + else + ComponentArrays.ComponentArray(init_params) + end + flat_init_params = if param_estim == false && multioutput + ComponentArrays.ComponentArray(; depvar = flat_init_params) + elseif param_estim == false && !multioutput + flat_init_params + else + ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p) + end + end + + if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer) + for ϕ in phi + ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), + ϕ.st) + end + elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer) + phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), + phi.st) + end + + eltypeθ = eltype(flat_init_params) + + if adaptive_loss === nothing + adaptive_loss = NonAdaptiveLoss{eltypeθ}() + end + + pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, - param_estim, additional_loss, adaloss, v, logger, + param_estim, additional_loss, adaptive_loss, v, logger, multioutput, iteration, init_params, flat_init_params, phi, derivative, strategy, eqdata, nothing, nothing, nothing, nothing) @@ -540,15 +616,15 @@ end """ ```julia -prob = discretize(pde_system::PDESystem, discretization::PhysicsInformedNN) +prob = discretize(pdesys::PDESystem, discretization::PhysicsInformedNN) ``` Transforms a symbolic description of a ModelingToolkit-defined `PDESystem` and generates an `OptimizationProblem` for [Optimization.jl](https://docs.sciml.ai/Optimization/stable/) whose solution is the solution to the PDE. 
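A minimal usage sketch, mirroring `test/NNPDE_tests.jl` in this changeset; the `PDESystem` (`pde_system`), the neural network `chain`, and the grid step used here are assumed to be defined by the caller:

```julia
# assumed to exist already, e.g.
# @named pde_system = PDESystem(eq, bcs, domains, [θ], [u(θ)])
discretization = PhysicsInformedNN(chain, GridTraining(0.1))
prob = discretize(pde_system, discretization)
res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.1); maxiters = 1000)
phi = discretization.phi   # trial solution; evaluate with phi(x, res.u)
```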
""" -function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInformedNN) - pinnrep = symbolic_discretize(pde_system, discretization) +function SciMLBase.discretize(pdesys::PDESystem, discretization::PhysicsInformedNN) + pinnrep = symbolic_discretize(pdesys, discretization) f = OptimizationFunction(pinnrep.loss_functions.full_loss_function, Optimization.AutoZygote()) Optimization.OptimizationProblem(f, pinnrep.flat_init_params) diff --git a/src/eq_data.jl b/src/eq_data.jl index 6dfb65df7d..e6e5abfe25 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -6,7 +6,7 @@ struct EquationData <: PDEBase.AbstractVarEqMapping argmap end -function EquationData(pdesys, v) +function EquationData(pdesys, v, strategy) eqs = pdesys.eqs bcs = pdesys.bcs alleqs = vcat(eqs, bcs) @@ -61,13 +61,13 @@ argument(eq, eqdata) = eqdata.argmap[eq] function get_iv_argument(eqs, v::VariableMap) vars = map(eqs) do eq - _vars = map(depvar -> get_depvars(eq, depvar), v.depvar_ops) + _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) - v.args[operation(map(x -> first(x), f_vars))] + map(vars -> map(op -> v.args[op], operation.(vars)), f_vars) end args_ = map(vars) do _vars seen = [] - filter(reduce(vcat, arguments.(_vars))) do x + filter(reduce(vcat, arguments.(_vars), init = [])) do x if x isa Number true else diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 97beeee8c4..1992481f94 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -8,9 +8,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; integrand = nothing, transformation_vars = nothing) @unpack v, eqdata, - phi, derivative, integral, - multioutput, init_params, strategy, eq_params, - param_estim, default_p = pinnrep + phi, derivative, integral, + multioutput, init_params, strategy, eq_params, + param_estim, default_p = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) @@ -117,7 +117,6 @@ end function generate_derivative_rules(eq, eqdata, dummyvars) phi, u, θ = dummyvars - @register_symbolic derivative(phi, u, coord, εs, order, θ) rs = [@rule $(Differential(~x)^(~d::isinteger)(~w)) => derivative(phi, u, ~x, get_εs(~w), ~d, θ)] # TODO: add mixed derivatives return rs diff --git a/src/pinn_types.jl b/src/pinn_types.jl index f73a53fb1e..cffbb17734 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -26,19 +26,20 @@ function logscalar(logger, s::R, name::AbstractString, step::Integer) where {R < nothing end + """ ```julia PhysicsInformedNN(chain, - strategy; - init_params = nothing, - phi = nothing, - param_estim = false, - additional_loss = nothing, - adaptive_loss = nothing, - logger = nothing, - log_options = LogOptions(), - iteration = nothing, - kwargs...) where {iip} + strategy; + init_params = nothing, + phi = nothing, + param_estim = false, + additional_loss = nothing, + adaptive_loss = nothing, + logger = nothing, + log_options = LogOptions(), + iteration = nothing, + kwargs...) where {iip} ``` A `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a @@ -60,7 +61,6 @@ methodology. `init_params` should match `Flux.destructure(chain)[1]` in shape. If `init_params` is not given, then the neural network default parameters are used. Note that for Lux, the default will convert to Float64. -* `flat_init_params`: the initial parameters of the neural networks, flattened into a vector. 
* `phi`: a trial solution, specified as `phi(x,p)` where `x` is the coordinates vector for the dependent variable and `p` are the weights of the phi function (generally the weights of the neural network defining `phi`). By default, this is generated from the `chain`. This @@ -79,11 +79,10 @@ methodology. * `iteration`: used to control the iteration counter??? * `kwargs`: Extra keyword arguments which are splatted to the `OptimizationProblem` on `solve`. """ -struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K, F} <: SciMLBase.AbstractDiscretization +struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: SciMLBase.AbstractDiscretization chain::Any strategy::T init_params::P - flat_init_params::F phi::PH derivative::DER param_estim::PE @@ -96,126 +95,49 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K, F} <: SciMLBase.Abs multioutput::Bool kwargs::K - @add_kwonly function PhysicsInformedNN(chain, - strategy; - init_params = nothing, - phi = nothing, - derivative = nothing, - param_estim = false, - additional_loss = nothing, - adaptive_loss = nothing, - logger = nothing, - log_options = LogOptions(), - iteration = nothing, - kwargs...) where {iip} - multioutput = typeof(chain) <: AbstractArray - - if phi === nothing - if multioutput - phi = Phi.(chain) - else - phi = Phi(chain) - end - end - - if derivative === nothing - _derivative = numeric_derivative - else - _derivative = derivative - end - - if iteration isa Vector{Int64} - self_increment = false - else - iteration = [1] - self_increment = true - end - - if init_params === nothing - # Use the initialization of the neural network framework - # But for Lux, default to Float64 - # For Flux, default to the types matching the values in the neural network - # This is done because Float64 is almost always better for these applications - # But with Flux there's already a chosen type from the user - - if chain isa AbstractArray - if chain[1] isa Flux.Chain - init_params = map(chain) do x - _x = Flux.destructure(x)[1] - end - else - x = map(chain) do x - _x = ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), - x)) - Float64.(_x) # No ComponentArray GPU support - end - names = ntuple(i -> depvars[i], length(chain)) - init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i - for i in x)) - end - else - if chain isa Flux.Chain - init_params = Flux.destructure(chain)[1] - init_params = init_params isa Array ? Float64.(init_params) : - init_params - else - init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(), - chain))) - end - end - else - init_params = init_params - end - - if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || - (!(phi isa Vector) && phi.f isa Optimisers.Restructure) - # Flux.Chain - flat_init_params = multioutput ? reduce(vcat, init_params) : init_params - flat_init_params = param_estim == false ? 
flat_init_params : - vcat(flat_init_params, - adapt(typeof(flat_init_params), default_p)) - else - flat_init_params = if init_params isa ComponentArrays.ComponentArray - init_params - elseif multioutput - @assert length(init_params) == length(depvars) - names = ntuple(i -> depvars[i], length(init_params)) - x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params)) - else - ComponentArrays.ComponentArray(init_params) - end - flat_init_params = if param_estim == false && multioutput - ComponentArrays.ComponentArray(; depvar = flat_init_params) - elseif param_estim == false && !multioutput - flat_init_params - else - ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p) - end - end - - if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer) - for ϕ in phi - ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), - ϕ.st) - end - elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer) - phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), - phi.st) - end - - eltypeθ = eltype(flat_init_params) - - if adaloss === nothing - adaloss = NonAdaptiveLoss{eltypeθ}() - end + @add_kwonly function PhysicsInformedNN(chain, + strategy; + init_params = nothing, + phi = nothing, + derivative = nothing, + param_estim = false, + additional_loss = nothing, + adaptive_loss = nothing, + logger = nothing, + log_options = LogOptions(), + iteration = nothing, + kwargs...) where {iip} + multioutput = typeof(chain) <: AbstractArray + + if phi === nothing + if multioutput + _phi = Phi.(chain) + else + _phi = Phi(chain) + end + else + _phi = phi + end + + if derivative === nothing + _derivative = numeric_derivative + else + _derivative = derivative + end + + if iteration isa Vector{Int64} + self_increment = false + else + iteration = [1] + self_increment = true + end new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), typeof(param_estim), typeof(additional_loss), typeof(adaptive_loss), - typeof(logger), typeof(kwargs), typeof(flat_init_params)}(chain, + typeof(logger), typeof(kwargs)}(chain, strategy, init_params, - flat_init_params, _phi, _derivative, param_estim, @@ -230,6 +152,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K, F} <: SciMLBase.Abs end end + """ `PINNRepresentation`` @@ -476,3 +399,5 @@ function numeric_derivative(phi, u, x, εs, order, θ) error("This shouldn't happen! 
Got an order of $(order).") end end + +@register_symbolic numeric_derivative(phi, u, coord, εs, order, θ) diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 7c423657db..b29bb61417 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -358,13 +358,13 @@ function get_argument end function get_argument(eqs, v::VariableMap) vars = map(eqs) do eq - _vars = map(depvar -> get_depvars(eq, depvar), v.depvar_ops) + _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) - map(x -> first(x), f_vars) + map(first, f_vars) end args_ = map(vars) do _vars seen = [] - filter(reduce(vcat, arguments.(_vars))) do x + filter(reduce(vcat, arguments.(_vars), init = [])) do x if x isa Number true else diff --git a/test/NNPDE_tests.jl b/test/NNPDE_tests.jl index bf8f91b972..7303ebaa3e 100644 --- a/test/NNPDE_tests.jl +++ b/test/NNPDE_tests.jl @@ -37,7 +37,7 @@ function test_ode(strategy_) discretization = NeuralPDE.PhysicsInformedNN(chain, strategy_) - @named pde_system = PDESystem(eq, bcs, domains, [θ], [u]) + @named pde_system = PDESystem(eq, bcs, domains, [θ], [u(θ)]) prob = NeuralPDE.discretize(pde_system, discretization) res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.1); maxiters = 1000) From c23df24f46db90f49678bdb193194af000f304f6 Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 30 May 2023 17:43:57 +0100 Subject: [PATCH 12/40] more fixes --- src/discretize.jl | 12 ++++++------ src/loss_function_generation.jl | 26 +++++++++++++------------- src/pinn_types.jl | 14 +------------- 3 files changed, 20 insertions(+), 32 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index 20d9437035..e9e474a69d 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -482,13 +482,13 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, eltypeθ = eltype(flat_init_params) - if adaptive_loss === nothing - adaptive_loss = NonAdaptiveLoss{eltypeθ}() + if adaloss === nothing + adaloss = NonAdaptiveLoss{eltypeθ}() end pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, - param_estim, additional_loss, adaptive_loss, v, logger, + param_estim, additional_loss, adaloss, v, logger, multioutput, iteration, init_params, flat_init_params, phi, derivative, strategy, eqdata, nothing, nothing, nothing, nothing) @@ -497,11 +497,11 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, #symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq) for eq in eqs] - #symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc) for bc in bcs] + #symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc) |> toexpr for bc in bcs] #pinnrep.integral = integral - pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions - pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions + #pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions + #pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions datafree_pde_loss_functions = [build_loss_function(pinnrep, eq) for eq in eqs] diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 1992481f94..3e87269190 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -7,19 +7,19 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; default_p = [], integrand = nothing, transformation_vars = nothing) - @unpack v, eqdata, + @unpack varmap, eqdata, phi, derivative, integral, multioutput, init_params, strategy, 
eq_params, param_estim, default_p = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - eq_args = get(eqdata.eq_args, eq, v.x̄) + eq_args = get(eqdata.ivargs, eq, varmap.x̄) if integrand isa Nothing - this_eq_indvars = indvars(eq, eqmap) - this_eq_depvars = depvars(eq, eqmap) - loss_function = parse_equation(pinnrep, eq, eq_iv_args(eq, eqmap)) + this_eq_indvars = indvars(eq, eqdata) + this_eq_depvars = depvars(eq, eqdata) + loss_function = parse_equation(pinnrep, eq, eq_iv_args(eq, eqdata)) else this_eq_indvars = transformation_vars isa Nothing ? unique(indvars(eq, eqmap)) : transformation_vars @@ -56,8 +56,8 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; end end - full_loss_func = (cord, θ, phi, derivative, integral, u, p) -> begin - loss_function(get_coords(cord), θ, phi, derivative, integral, u, get_ps(θ)) + full_loss_func = (cord, θ, phi, u, p) -> begin + loss_function(get_coords(cord), θ, phi, u, get_ps(θ)) end return full_loss_func end @@ -70,7 +70,7 @@ function build_loss_function(pinnrep, eqs) param_estim = param_estim) u = get_u() - loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, derivative, integral, u, + loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, u, default_p) end return loss_function end @@ -90,22 +90,22 @@ end function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = false, dict_transformation_vars = nothing, transformation_vars = nothing) - @unpack v, eqdata, derivative, integral = pinnrep + @unpack varmap, eqdata, derivative, integral = pinnrep expr = eq isa Equation ? eq.lhs : eq - ex_vars = vars(expr) + ex_vars = get_depvars(expr, varmap.depvar_ops) ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] dummyvars = @variables phi, u, θ - deriv_rules = generate_derivative_rules(eq, eqdata, dummyvars) + deriv_rules = generate_derivative_rules(eq, eqdata, dummyvars, derivative) ch = Postwalk(Chain([deriv_rules; op_rules])) expr = ch(expr) sym_coords = DestructuredArgs(ivs) - ps = DestructuredArgs(v.ps) + ps = DestructuredArgs(varmap.ps) args = [sym_coords, θ, phi, u, ps] @@ -115,7 +115,7 @@ function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = fals return ex end -function generate_derivative_rules(eq, eqdata, dummyvars) +function generate_derivative_rules(eq, eqdata, dummyvars, derivative) phi, u, θ = dummyvars rs = [@rule $(Differential(~x)^(~d::isinteger)(~w)) => derivative(phi, u, ~x, get_εs(~w), ~d, θ)] # TODO: add mixed derivatives diff --git a/src/pinn_types.jl b/src/pinn_types.jl index cffbb17734..faedb016ac 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -250,19 +250,7 @@ mutable struct PINNRepresentation """ ??? """ - pde_indvars::Any - """ - ??? - """ - bc_indvars::Any - """ - ??? - """ - pde_integration_vars::Any - """ - ??? - """ - bc_integration_vars::Any + eqdata::Any """ ??? 
""" From 5fdc7a5e7063fd270072248acb2ad413f0a4605c Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 31 May 2023 16:59:20 +0100 Subject: [PATCH 13/40] move to varmap, deprecate old parsing --- src/NeuralPDE.jl | 3 + src/discretize.jl | 234 ++------------------------------ src/loss_function_generation.jl | 3 +- src/neural_adapter.jl | 44 +++--- src/training_strategies.jl | 32 ++--- 5 files changed, 47 insertions(+), 269 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 7a87035a65..d13e84d378 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -24,6 +24,8 @@ using DomainSets using Symbolics using Symbolics: wrap, unwrap, arguments, operation using SymbolicUtils +using SymbolicUtils.Code +using SymbolicUtils: Postwalk, Chain import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives import DomainSets: Domain, ClosedInterval import ModelingToolkit: Interval, infimum, supremum #,Ball @@ -40,6 +42,7 @@ RuntimeGeneratedFunctions.init(@__MODULE__) abstract type AbstractPINN <: SciMLBase.AbstractDiscretization end abstract type AbstractTrainingStrategy end +abstract type AbstractGridfreeStrategy <: AbstractTrainingStrategy end include("pinn_types.jl") include("eq_data.jl") diff --git a/src/discretize.jl b/src/discretize.jl index e9e474a69d..a3d70ed66d 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -1,199 +1,3 @@ -""" -Build a loss function for a PDE or a boundary condition - -# Examples: System of PDEs: - -Take expressions in the form: - -[Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0, - Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0] - -to - -:((cord, θ, phi, derivative, u)->begin - #= ... =# - #= ... =# - begin - (θ1, θ2) = (θ[1:33], θ"[34:66]) - (phi1, phi2) = (phi[1], phi[2]) - let (x, y) = (cord[1], cord[2]) - [(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, θ1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, θ2))) - 0, - (+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, θ2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, θ1))) - 0] - end - end - end) - -for Flux.Chain, and - -:((cord, θ, phi, derivative, u)->begin - #= ... =# - #= ... =# - begin - (u1, u2) = (θ.depvar.u1, θ.depvar.u2) - (phi1, phi2) = (phi[1], phi[2]) - let (x, y) = (cord[1], cord[2]) - [(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, u1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, u1))) - 0, - (+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, u2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, u2))) - 0] - end - end - end) - -for Lux.AbstractExplicitLayer -""" -# function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; -# eq_params = SciMLBase.NullParameters(), -# param_estim = false, -# default_p = nothing, -# bc_indvars = pinnrep.v.x̄, -# integrand = nothing, -# dict_transformation_vars = nothing, -# transformation_vars = nothing, -# integrating_depvars = pinnrep.v.ū) -# @unpack v, eqdata, -# phi, derivative, integral, -# multioutput, init_params, strategy, eq_params, -# param_estim, default_p = pinnrep - -# eltypeθ = eltype(pinnrep.flat_init_params) - -# if integrand isa Nothing -# loss_function = parse_equation(pinnrep, eq) -# this_eq_pair = pair(eq, depvars, dict_depvars, dict_depvar_input) -# this_eq_indvars = indvars(eq, eqmap) -# else -# this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars], -# integrating_depvars)) -# this_eq_indvars = transformation_vars isa Nothing ? 
-# unique(indvars(eq, eqmap)) : transformation_vars -# loss_function = integrand -# end - -# vars = :(cord, $θ, phi, derivative, integral, u, p) -# ex = Expr(:block) -# if multioutput -# θ_nums = Symbol[] -# phi_nums = Symbol[] -# for v in depvars -# num = dict_depvars[v] -# push!(θ_nums, :($(Symbol(:($θ), num)))) -# push!(phi_nums, :($(Symbol(:phi, num)))) -# end - -# expr_θ = Expr[] -# expr_phi = Expr[] - -# acum = [0; accumulate(+, map(length, init_params))] -# sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] - -# for i in eachindex(depvars) -# if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || -# (!(phi isa Vector) && phi.f isa Optimisers.Restructure) -# # Flux.Chain -# push!(expr_θ, :($θ[$(sep[i])])) -# else # Lux.AbstractExplicitLayer -# push!(expr_θ, :($θ.depvar.$(depvars[i]))) -# end -# push!(expr_phi, :(phi[$i])) -# end - -# vars_θ = Expr(:(=), build_expr(:tuple, θ_nums), build_expr(:tuple, expr_θ)) -# push!(ex.args, vars_θ) - -# vars_phi = Expr(:(=), build_expr(:tuple, phi_nums), build_expr(:tuple, expr_phi)) -# push!(ex.args, vars_phi) -# end - -# #Add an expression for parameter symbols -# if param_estim == true && eq_params != SciMLBase.NullParameters() -# param_len = length(eq_params) -# last_indx = [0; accumulate(+, map(length, init_params))][end] -# params_symbols = Symbol[] -# expr_params = Expr[] -# for (i, eq_param) in enumerate(eq_params) -# if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || -# (!(phi isa Vector) && phi.f isa Optimisers.Restructure) -# push!(expr_params, :($θ[$((i + last_indx):(i + last_indx))])) -# else -# push!(expr_params, :($θ.p[$((i):(i))])) -# end -# push!(params_symbols, Symbol(:($eq_param))) -# end -# params_eq = Expr(:(=), build_expr(:tuple, params_symbols), -# build_expr(:tuple, expr_params)) -# push!(ex.args, params_eq) -# end - -# if eq_params != SciMLBase.NullParameters() && param_estim == false -# params_symbols = Symbol[] -# expr_params = Expr[] -# for (i, eq_param) in enumerate(eq_params) -# push!(expr_params, :(ArrayInterface.allowed_getindex(p, ($i):($i)))) -# push!(params_symbols, Symbol(:($eq_param))) -# end -# params_eq = Expr(:(=), build_expr(:tuple, params_symbols), -# build_expr(:tuple, expr_params)) -# push!(ex.args, params_eq) -# end - -# eq_pair_expr = Expr[] -# for i in keys(this_eq_pair) -# push!(eq_pair_expr, :($(Symbol(:cord, :($i))) = vcat($(this_eq_pair[i]...)))) -# end -# vcat_expr = Expr(:block, :($(eq_pair_expr...))) -# vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename - -# if strategy isa QuadratureTraining -# indvars_ex = get_indvars_ex(bc_indvars) -# left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex -# vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), -# build_expr(:tuple, right_arg_pairs)) -# else -# indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] -# left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex -# vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), -# build_expr(:tuple, right_arg_pairs)) -# end - -# if !(dict_transformation_vars isa Nothing) -# transformation_expr_ = Expr[] - -# for (i, u) in dict_transformation_vars -# push!(transformation_expr_, :($i = $u)) -# end -# transformation_expr = Expr(:block, :($(transformation_expr_...))) -# vcat_expr_loss_functions = Expr(:block, transformation_expr, vcat_expr, -# loss_function) -# end -# let_ex = Expr(:let, vars_eq, vcat_expr_loss_functions) -# push!(ex.args, let_ex) -# expr_loss_function = :(($vars) -> begin $ex end) -# end - -# """ 
-# ```julia -# build_loss_function(eqs, indvars, depvars, phi, derivative, init_params; bc_indvars=nothing) -# ``` - -# Returns the body of loss function, which is the executable Julia function, for the main -# equation or boundary condition. -# """ -# function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars) -# @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep - -# bc_indvars = bc_indvars === nothing ? pinnrep.indvars : bc_indvars - -# expr_loss_function = build_symbolic_loss_function(pinnrep, eqs; -# bc_indvars = bc_indvars, -# eq_params = eq_params, -# param_estim = param_estim, -# default_p = default_p) -# u = get_u() -# _loss_function = @RuntimeGeneratedFunction(expr_loss_function) -# loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, derivative, integral, u, -# default_p) end -# return loss_function -# end - """ ```julia generate_training_sets(domains,dx,bcs,_indvars::Array,_depvars::Array) @@ -204,17 +8,8 @@ strategy. """ function generate_training_sets end -function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, _indvars::Array, - _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars, - dict_depvars) -end - # Generate training set in the domain and on the boundary -function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::Dict, - dict_depvars::Dict) +function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap) if dx isa Array dxs = dx else @@ -222,11 +17,11 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D end spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)] - dict_var_span = Dict([Symbol(d.variables) => infimum(d.domain):dx:supremum(d.domain) + dict_var_span = Dict([d.variables => infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]) - bound_args = get_argument(bcs, dict_indvars, dict_depvars) - bound_vars = get_variables(bcs, dict_indvars, dict_depvars) + bound_args = get_argument(bcs, varmap) + bound_vars = get_variables(bcs, varmap) dif = [eltypeθ[] for i in 1:size(domains)[1]] for _args in bound_vars @@ -241,7 +36,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D setdiff(c, d) end - dict_var_span_ = Dict([Symbol(d.variables) => bc for (d, bc) in zip(domains, bc_data)]) + dict_var_span_ = Dict([d.variables => bc for (d, bc) in zip(domains, bc_data)]) bcs_train_sets = map(bound_args) do bt span = map(b -> get(dict_var_span, b, b), bt) @@ -249,8 +44,8 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)) end - pde_vars = get_variables(eqs, dict_indvars, dict_depvars) - pde_args = get_argument(eqs, dict_indvars, dict_depvars) + pde_vars = get_variables(eqs, varmap) + pde_args = get_argument(eqs, varmap) pde_train_set = adapt(eltypeθ, hcat(vec(map(points -> collect(points), @@ -274,20 +69,7 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining. 
""" function get_bounds end -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) -end - -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, - strategy::QuadratureTraining) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) -end - -function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::AbstractGridfreeTraining) dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) pde_args = get_argument(eqs, v) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 3e87269190..2fb422ed33 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -95,6 +95,7 @@ function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = fals expr = eq isa Equation ? eq.lhs : eq ex_vars = get_depvars(expr, varmap.depvar_ops) ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) + ex_ops = operations(expr) ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] @@ -117,7 +118,7 @@ end function generate_derivative_rules(eq, eqdata, dummyvars, derivative) phi, u, θ = dummyvars - rs = [@rule $(Differential(~x)^(~d::isinteger)(~w)) => derivative(phi, u, ~x, get_εs(~w), ~d, θ)] + rs = [@rule ($Differential(~x)^(~d::isinteger))(~w) => derivative(phi, u, x, get_εs(w), d, θ)] # TODO: add mixed derivatives return rs end diff --git a/src/neural_adapter.jl b/src/neural_adapter.jl index 1e97c9eb70..4aa9fadbdd 100644 --- a/src/neural_adapter.jl +++ b/src/neural_adapter.jl @@ -15,18 +15,17 @@ function get_loss_function_(loss, init_params, pde_system, strategy::GridTrainin eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) + eltypeθ = eltype(init_params) dx = strategy.dx train_set = generate_training_sets(domains, dx, eqs, eltypeθ) get_loss_function(loss, train_set, eltypeθ, strategy) end -function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy) +function get_bounds_(domains, eqs, eltypeθ, varmap, strategy) dict_span = Dict([Symbol(d.variables) => [infimum(d.domain), supremum(d.domain)] for d in domains]) - args = get_argument(eqs, dict_indvars, dict_depvars) + args = get_argument(eqs, varmap) bounds = map(args) do pd span = map(p -> get(dict_span, p, p), pd) @@ -35,44 +34,38 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strateg bounds end -function get_loss_function_(loss, init_params, pde_system, strategy::StochasticTraining) +function get_loss_function_(loss, init_params, pde_system, varmap, strategy::StochasticTraining) eqs = pde_system.eqs if !(eqs isa Array) eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) - eltypeθ = eltype(init_params) - bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)[1] + bound = 
get_bounds_(domains, eqs, eltypeθ, varmap, strategy)[1] get_loss_function(loss, bound, eltypeθ, strategy) end -function get_loss_function_(loss, init_params, pde_system, strategy::QuasiRandomTraining) +function get_loss_function_(loss, init_params, pde_system, varmap, strategy::QuasiRandomTraining) eqs = pde_system.eqs if !(eqs isa Array) eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) - eltypeθ = eltype(init_params) - bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)[1] + bound = get_bounds_(domains, eqs, eltypeθ, varmap, strategy)[1] get_loss_function(loss, bound, eltypeθ, strategy) end -function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, +function get_bounds_(domains, eqs, eltypeθ, varmap, strategy::QuadratureTraining) - dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) - dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) + dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) + dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) - args = get_argument(eqs, dict_indvars, dict_depvars) + args = get_argument(eqs, varmap) lower_bounds = map(args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -85,18 +78,15 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, bound = lower_bounds, upper_bounds end -function get_loss_function_(loss, init_params, pde_system, strategy::QuadratureTraining) +function get_loss_function_(loss, init_params, pde_system, varmap, strategy::QuadratureTraining) eqs = pde_system.eqs if !(eqs isa Array) eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) - eltypeθ = eltype(init_params) - bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy) + bound = get_bounds_(domains, eqs, eltypeθ, varmap, strategy) lb, ub = bound get_loss_function(loss, lb[1], ub[1], eltypeθ, strategy) end @@ -118,7 +108,8 @@ Trains a neural network using the results from one already obtained prediction. 
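# Rough usage sketch. The names below (`phi1`/`res1` from a previous
# `PhysicsInformedNN` solve, and `phi2`/`init_params2` for the new network) are
# assumptions for illustration only, not part of this changeset. The user loss
# returns the residual; the training strategy applies mean(abs2, ...) on top.
#
#     loss_(cord, θ) = phi2(cord, θ) .- phi1(cord, res1.u)
#     prob2 = neural_adapter(loss_, init_params2, pde_system, GridTraining(0.1))
#     res2 = Optimization.solve(prob2, OptimizationOptimisers.Adam(0.01); maxiters = 2000)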
 function neural_adapter end
 
 function neural_adapter(loss, init_params, pde_system, strategy)
-    loss_function__ = get_loss_function_(loss, init_params, pde_system, strategy)
+    varmap = VariableMap(pde_system)
+    loss_function__ = get_loss_function_(loss, init_params, pde_system, varmap, strategy)
 
     function loss_function_(θ, p)
         loss_function__(θ)
@@ -128,8 +119,9 @@ function neural_adapter(loss, init_params, pde_system, strategy)
 end
 
 function neural_adapter(losses::Array, init_params, pde_systems::Array, strategy)
-    loss_functions_ = map(zip(losses, pde_systems)) do (l, p)
-        get_loss_function_(l, init_params, p, strategy)
+    varmaps = VariableMap.(pde_systems)
+    loss_functions_ = map(zip(losses, pde_systems, varmaps)) do (l, p, v)
+        get_loss_function_(l, init_params, p, v, strategy)
     end
     loss_function__ = θ -> sum(map(l -> l(θ), loss_functions_))
     function loss_function_(θ, p)
diff --git a/src/training_strategies.jl b/src/training_strategies.jl
index d87bc64d7c..e13ae178cc 100644
--- a/src/training_strategies.jl
+++ b/src/training_strategies.jl
@@ -20,12 +20,12 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
                                            strategy::GridTraining,
                                            datafree_pde_loss_function,
                                            datafree_bc_loss_function)
-    @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep
+    @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep
     dx = strategy.dx
     eltypeθ = eltype(pinnrep.flat_init_params)
     train_sets = generate_training_sets(domains, dx, eqs, bcs, eltypeθ,
-                                        dict_indvars, dict_depvars)
+                                        varmap)
 
     # the points in the domain and on the boundary
     pde_train_sets, bcs_train_sets = train_sets
@@ -62,7 +62,7 @@ StochasticTraining(points; bcs_points = points)
 * `bcs_points`: the number of points in the randomly selected training set for the boundary conditions (by default, it equals `points`).
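For example, a strategy that draws 1024 interior points and 256 boundary points for every loss evaluation (the counts here are only illustrative) can be constructed as:

```julia
strategy = StochasticTraining(1024; bcs_points = 256)
```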
""" -struct StochasticTraining <: AbstractTrainingStrategy +struct StochasticTraining <: AbstractGridfreeStrategy points::Int64 bcs_points::Int64 end @@ -80,11 +80,11 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::StochasticTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, v, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, v, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds @@ -136,7 +136,7 @@ that accelerate the convergence in high dimensional spaces over pure random sequ For more information, see [QuasiMonteCarlo.jl](https://docs.sciml.ai/QuasiMonteCarlo/stable/) """ -struct QuasiRandomTraining <: AbstractTrainingStrategy +struct QuasiRandomTraining <: AbstractGridfreeStrategy points::Int64 bcs_points::Int64 sampling_alg::QuasiMonteCarlo.SamplingAlgorithm @@ -162,22 +162,22 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuasiRandomTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds - pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy) + pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, varmap, strategy) for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)] strategy_ = QuasiRandomTraining(strategy.bcs_points; sampling_alg = strategy.sampling_alg, resampling = strategy.resampling, minibatch = strategy.minibatch) - bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy_) + bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, varmap, strategy_) for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)] pde_loss_functions, bc_loss_functions @@ -238,7 +238,7 @@ For more information on the argument values and algorithm choices, see [Integrals.jl](https://docs.sciml.ai/Integrals/stable/). 
""" struct QuadratureTraining{Q <: SciMLBase.AbstractIntegralAlgorithm, T} <: - AbstractTrainingStrategy + AbstractGridfreeStrategy quadrature_alg::Q reltol::T abstol::T @@ -255,18 +255,18 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuadratureTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds lbs, ubs = pde_bounds - pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy) + pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, varmap, strategy) for (_loss, lb, ub) in zip(datafree_pde_loss_function, lbs, ubs)] lbs, ubs = bcs_bounds - bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy) + bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, varmap, strategy) for (_loss, lb, ub) in zip(datafree_bc_loss_function, lbs, ubs)] pde_loss_functions, bc_loss_functions @@ -318,7 +318,7 @@ such that the total number of sampled points is equivalent to the given samples This training strategy can only be used with ODEs (`NNODE`). """ -struct WeightedIntervalTraining{T} <: AbstractTrainingStrategy +struct WeightedIntervalTraining{T} <: AbstractGridfreeStrategyy weights::Vector{T} samples::Int end From 5891313f2bbc55b851346c8e3d56fcc46040f203 Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 5 Jun 2023 14:41:36 +0100 Subject: [PATCH 14/40] more fixes --- src/NeuralPDE.jl | 2 +- src/discretize.jl | 4 +- src/eq_data.jl | 34 ++++++++--------- src/loss_function_generation.jl | 66 ++++++++++++++++++++++++++------- src/symbolic_utilities.jl | 43 ++++++--------------- src/training_strategies.jl | 10 ++--- 6 files changed, 89 insertions(+), 70 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index d13e84d378..288115847d 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -16,7 +16,7 @@ using QuasiMonteCarlo using RuntimeGeneratedFunctions using SciMLBase using PDEBase -using PDEBase: cardinalize_eqs!, get_depvars, get_indvars +using PDEBase: cardinalize_eqs!, get_depvars, get_indvars, differential_order using Statistics using ArrayInterface import Optim diff --git a/src/discretize.jl b/src/discretize.jl index a3d70ed66d..82981b0d51 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -69,7 +69,7 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining. 
""" function get_bounds end -function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::AbstractGridfreeTraining) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::AbstractGridfreeStrategy) dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) pde_args = get_argument(eqs, v) @@ -208,7 +208,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, x)) Float64.(_x) # No ComponentArray GPU support end - names = ntuple(i -> ~Symbol.(v.ū)[i], length(chain)) + names = ntuple(i -> Symbol.(v.ū)[i], length(chain)) init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in x)) end diff --git a/src/eq_data.jl b/src/eq_data.jl index e6e5abfe25..5108eecd08 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -7,8 +7,8 @@ struct EquationData <: PDEBase.AbstractVarEqMapping end function EquationData(pdesys, v, strategy) - eqs = pdesys.eqs - bcs = pdesys.bcs + eqs = map(eq -> eq.lhs, pdesys.eqs) + bcs = map(eq -> eq.lhs, pdesys.bcs) alleqs = vcat(eqs, bcs) argmap = map(alleqs) do eq @@ -21,20 +21,20 @@ function EquationData(pdesys, v, strategy) eq => get_indvars(eq, v) end |> Dict - args = map(alleqs) do eq - if strategy isa QuadratureTraining - eq => get_argument(bcs, v) - else - eq => get_variables(bcs, v) - end + if strategy isa QuadratureTraining + _args = get_argument(alleqs, v) + else + _args = get_variables(alleqs, v) + end + + args = map(zip(alleqs, _args)) do (eq, args) + eq => args end |> Dict - ivargs = map(alleqs) do eq - if strategy isa QuadratureTraining - eq => get_iv_argument(eqs, v) - else - eq => get_iv_variables(eqs, v) - end + ivargs = get_iv_argument(alleqs, v) + + ivargs = map(zip(alleqs, ivargs)) do (eq, args) + eq => args end |> Dict EquationData(depvarmap, indvarmap, args, ivargs, argmap) @@ -63,13 +63,13 @@ function get_iv_argument(eqs, v::VariableMap) vars = map(eqs) do eq _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) - map(vars -> map(op -> v.args[op], operation.(vars)), f_vars) + mapreduce(vars -> mapreduce(op -> v.args[op], vcat, operation.(vars), init = []), vcat, f_vars, init = []) end args_ = map(vars) do _vars seen = [] - filter(reduce(vcat, arguments.(_vars), init = [])) do x + filter(_vars) do x if x isa Number - true + error("Unreachable") else if any(isequal(x), seen) false diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 2fb422ed33..d073a99a13 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -14,6 +14,8 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; eltypeθ = eltype(pinnrep.flat_init_params) + eq = eq isa Equation ? 
eq.lhs : eq + eq_args = get(eqdata.ivargs, eq, varmap.x̄) if integrand isa Nothing @@ -68,7 +70,6 @@ function build_loss_function(pinnrep, eqs) _loss_function = build_symbolic_loss_function(pinnrep, eqs, eq_params = eq_params, param_estim = param_estim) - u = get_u() loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, u, default_p) end @@ -87,23 +88,23 @@ end # Parse equation ############################################################################################ -function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = false, +function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = false, dict_transformation_vars = nothing, transformation_vars = nothing) - @unpack varmap, eqdata, derivative, integral = pinnrep + @unpack varmap, eqdata, derivative, integral, flat_init_params = pinnrep + eltypeθ = eltype(flat_init_params) - expr = eq isa Equation ? eq.lhs : eq - ex_vars = get_depvars(expr, varmap.depvar_ops) + ex_vars = get_depvars(term, varmap.depvar_ops) ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) - ex_ops = operations(expr) + ex_ops = operations(term) ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] dummyvars = @variables phi, u, θ - deriv_rules = generate_derivative_rules(eq, eqdata, dummyvars, derivative) + deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) + @show deriv_rules - ch = Postwalk(Chain([deriv_rules; op_rules])) - expr = ch(expr) + expr = substitute(term, deriv_rules) sym_coords = DestructuredArgs(ivs) ps = DestructuredArgs(varmap.ps) @@ -112,15 +113,52 @@ function parse_equation(pinnrep::PINNRepresentation, eq, ivs; is_integral = fals args = [sym_coords, θ, phi, u, ps] ex = Func(args, [], expr) |> toexpr + @show ex + f = @RuntimeGeneratedFunction ex + return f +end - return ex +function get_ε(dim::Int, der_num::Int, ::Type{eltypeθ}, order) where {eltypeθ} + epsilon = ^(eps(eltypeθ), one(eltypeθ) / (2 + order)) + ε = zeros(eltypeθ, dim) + ε[der_num] = epsilon + ε end -function generate_derivative_rules(eq, eqdata, dummyvars, derivative) +function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) phi, u, θ = dummyvars - rs = [@rule ($Differential(~x)^(~d::isinteger))(~w) => derivative(phi, u, x, get_εs(w), d, θ)] - # TODO: add mixed derivatives - return rs + dvs = depvars(term, eqdata) + @show dvs + # Orthodox derivatives + rs = [reduce(vcat, [reduce(vcat, [(Differential(x)^d)(w) => + derivative(phi, + u, x, + get_ε(length(arguments(w)), + j, eltypeθ, d), + d, θ) + for d in differential_orders(term, x)], init = []) + for (j, x) in enumerate(varmap.args[operation(w)])], init = []) + for w in dvs] + # Mixed derivatives + mx = mapreduce(vcat, dvs, init = []) do w + mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (j, x) + map(enumerate(varmap.args[operation(w)])) do (k, y) + if isequal(x, y) + nothing => nothing + else + n = length(arguments(w)) + @rule (Differential(x))((Differential(y))(w)) => + derivative(phi, + (cord_, θ_, phi_) -> + derivative(phi, u, y, + get_ϵ(n, k, eltypeθ, 2), 1, θ), + x, get_ε(n, j, eltypeθ, 2), 1, θ) + end + end + end + end + + return [rs; mx] end function generate_integral_rules(eq, eqdata, dummyvars) diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index b29bb61417..b439d82aea 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -1,5 +1,17 @@ using 
Base.Broadcast +function get_limits(domain) + if domain isa AbstractInterval + return [leftendpoint(domain)], [rightendpoint(domain)] + elseif domain isa ProductDomain + return collect(map(leftendpoint, DomainSets.components(domain))), + collect(map(rightendpoint, DomainSets.components(domain))) + end +end + +θ = gensym("θ") + + """ Override `Broadcast.__dot__` with `Broadcast.dottable(x::Function) = true` @@ -19,37 +31,6 @@ julia> _dot_(e) dottable_(x) = Broadcast.dottable(x) dottable_(x::Function) = true -_dot_(x) = x -function _dot_(x::Expr) - dotargs = Base.mapany(_dot_, x.args) - if x.head === :call && dottable_(x.args[1]) - Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) - elseif x.head === :comparison - Expr(:comparison, - (iseven(i) && dottable_(arg) && arg isa Symbol && isoperator(arg) ? - Symbol('.', arg) : arg for (i, arg) in pairs(dotargs))...) - elseif x.head === :$ - x.args[1] - elseif x.head === :let # don't add dots to `let x=...` assignments - Expr(:let, undot(dotargs[1]), dotargs[2]) - elseif x.head === :for # don't add dots to for x=... assignments - Expr(:for, undot(dotargs[1]), dotargs[2]) - elseif (x.head === :(=) || x.head === :function || x.head === :macro) && - Meta.isexpr(x.args[1], :call) # function or macro definition - Expr(x.head, x.args[1], dotargs[2]) - elseif x.head === :(<:) || x.head === :(>:) - tmp = x.head === :(<:) ? :.<: : :.>: - Expr(:call, tmp, dotargs...) - else - head = String(x.head)::String - if last(head) == '=' && first(head) != '.' || head == "&&" || head == "||" - Expr(Symbol('.', head), dotargs...) - else - Expr(x.head, dotargs...) - end - end -end - """ Create dictionary: variable => unique number for variable diff --git a/src/training_strategies.jl b/src/training_strategies.jl index e13ae178cc..e2e696a61e 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -170,14 +170,14 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy) pde_bounds, bcs_bounds = bounds - pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, varmap, strategy) + pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy) for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)] strategy_ = QuasiRandomTraining(strategy.bcs_points; sampling_alg = strategy.sampling_alg, resampling = strategy.resampling, minibatch = strategy.minibatch) - bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, varmap, strategy_) + bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy_) for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)] pde_loss_functions, bc_loss_functions @@ -263,10 +263,10 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, pde_bounds, bcs_bounds = bounds lbs, ubs = pde_bounds - pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, varmap, strategy) + pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy) for (_loss, lb, ub) in zip(datafree_pde_loss_function, lbs, ubs)] lbs, ubs = bcs_bounds - bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, varmap, strategy) + bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy) for (_loss, lb, ub) in zip(datafree_bc_loss_function, lbs, ubs)] pde_loss_functions, bc_loss_functions @@ -318,7 +318,7 @@ such that the total number of sampled points is equivalent to the given samples This training strategy can only be used with ODEs (`NNODE`). 
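# Illustration (not from the patch): a quick check of the get_limits helper added
# above, using the DomainSets.jl types NeuralPDE already depends on. The demo
# re-declares the same body so the snippet runs on its own.
using DomainSets
function get_limits_demo(domain)
    if domain isa AbstractInterval
        return [leftendpoint(domain)], [rightendpoint(domain)]
    elseif domain isa ProductDomain
        return collect(map(leftendpoint, DomainSets.components(domain))),
               collect(map(rightendpoint, DomainSets.components(domain)))
    end
end
get_limits_demo(0.0 .. 1.0)                              # ([0.0], [1.0])
get_limits_demo(ProductDomain(0.0 .. 1.0, 2.0 .. 3.0))   # ([0.0, 2.0], [1.0, 3.0])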
""" -struct WeightedIntervalTraining{T} <: AbstractGridfreeStrategyy +struct WeightedIntervalTraining{T} <: AbstractGridfreeStrategy weights::Vector{T} samples::Int end From 18c34490a19bdc77c4cd8e10f5b29edc8fa45063 Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 7 Jun 2023 15:26:22 +0100 Subject: [PATCH 15/40] fix the tests a bit more --- src/NeuralPDE.jl | 2 +- src/loss_function_generation.jl | 52 +++++++++++++++++++-------------- src/pinn_types.jl | 3 +- 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 288115847d..3e0d3fe1b5 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -25,7 +25,7 @@ using Symbolics using Symbolics: wrap, unwrap, arguments, operation using SymbolicUtils using SymbolicUtils.Code -using SymbolicUtils: Postwalk, Chain +using SymbolicUtils: Prewalk, Postwalk, Chain import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives import DomainSets: Domain, ClosedInterval import ModelingToolkit: Interval, infimum, supremum #,Ball diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index d073a99a13..e3d881575a 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -49,9 +49,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; end function get_coords(cord) - map(enumerate(eq_args)) do (i, x) + mapreduce(vcat, enumerate(eq_args)) do (i, x) if x isa Number - fill(x, size(cord[[1], :])) + fill(convert(eltypeθ, x), size(cord[[1], :])) else cord[[i], :] end @@ -96,21 +96,26 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa ex_vars = get_depvars(term, varmap.depvar_ops) ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) - ex_ops = operations(term) - ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) - op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] - dummyvars = @variables phi, u, θ + dummyvars = @variables phi, u(..), θ, coord deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) - @show deriv_rules - expr = substitute(term, deriv_rules) + ch = Prewalk(Chain(deriv_rules)) + + expr = ch(term) + + # Broadcast + ex_ops = operations(expr) + ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) + op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) 
for op in ex_ops] + dotch = Prewalk(Chain(op_rules)) + expr = dotch(expr) sym_coords = DestructuredArgs(ivs) ps = DestructuredArgs(varmap.ps) - args = [sym_coords, θ, phi, u, ps] + args = [coord, θ, phi, u, ps] ex = Func(args, [], expr) |> toexpr @show ex @@ -126,39 +131,42 @@ function get_ε(dim::Int, der_num::Int, ::Type{eltypeθ}, order) where {eltypeθ end function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) - phi, u, θ = dummyvars + phi, u, θ, coord = dummyvars dvs = depvars(term, eqdata) @show dvs # Orthodox derivatives - rs = [reduce(vcat, [reduce(vcat, [(Differential(x)^d)(w) => + rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => derivative(phi, - u, x, - get_ε(length(arguments(w)), - j, eltypeθ, d), + u, coord, + [get_ε(length(arguments(w)), + j, eltypeθ, d)], d, θ) - for d in differential_orders(term, x)], init = []) + for d in differential_order(term, x)] for (j, x) in enumerate(varmap.args[operation(w)])], init = []) - for w in dvs] + for w in dvs], init = []) # Mixed derivatives mx = mapreduce(vcat, dvs, init = []) do w mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (j, x) map(enumerate(varmap.args[operation(w)])) do (k, y) if isequal(x, y) - nothing => nothing + (_) -> nothing else n = length(arguments(w)) - @rule (Differential(x))((Differential(y))(w)) => + @rule $((Differential(x))((Differential(y))(w))) => derivative(phi, (cord_, θ_, phi_) -> - derivative(phi, u, y, - get_ϵ(n, k, eltypeθ, 2), 1, θ), - x, get_ε(n, j, eltypeθ, 2), 1, θ) + derivative(phi_, u, cord_, + [get_ϵ(n, k, eltypeθ, 2)], 1, θ_), + coord, [get_ε(n, j, eltypeθ, 2)], 1, θ) end end end end + vr = mapreduce(vcat, dvs, init = []) do w + @rule w => u(coord, θ, phi) + end - return [rs; mx] + return [mx; rs; vr] end function generate_integral_rules(eq, eqdata, dummyvars) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index faedb016ac..0c9d20f20f 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -351,10 +351,9 @@ function get_u() end # the method to calculate the derivative -function numeric_derivative(phi, u, x, εs, order, θ) +function numeric_derivative(phi, u, x, ε, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) - ε = εs[order] _epsilon = inv(first(ε[ε.!=zero(ε)])) ε = adapt(_type, ε) From 82191e753d985c0f66bad0ce1915433904d175e1 Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 7 Jun 2023 17:29:17 +0100 Subject: [PATCH 16/40] remove transform expression --- src/symbolic_utilities.jl | 161 -------------------------------------- 1 file changed, 161 deletions(-) diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index b439d82aea..b9fed4c216 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -95,167 +95,6 @@ where order - order of derivative θ - weight in neural network """ -function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = false, - dict_transformation_vars = nothing, - transformation_vars = nothing) - @unpack indvars, depvars, dict_indvars, dict_depvars, - dict_depvar_input, multioutput, strategy, phi, - derivative, integral, flat_init_params, init_params = pinnrep - eltypeθ = eltype(flat_init_params) - - _args = ex.args - for (i, e) in enumerate(_args) - if !(e isa Expr) - if e in keys(dict_depvars) - depvar = _args[1] - num_depvar = dict_depvars[depvar] - indvars = _args[2:end] - var_ = is_integral ? 
:(u) : :($(Expr(:$, :u))) - ex.args = if !multioutput - [var_, Symbol(:cord, num_depvar), :($θ), :phi] - else - [ - var_, - Symbol(:cord, num_depvar), - Symbol(:($θ), num_depvar), - Symbol(:phi, num_depvar), - ] - end - break - elseif e isa ModelingToolkit.Differential - derivative_variables = Symbol[] - order = 0 - while (_args[1] isa ModelingToolkit.Differential) - order += 1 - push!(derivative_variables, toexpr(_args[1].x)) - _args = _args[2].args - end - depvar = _args[1] - num_depvar = dict_depvars[depvar] - indvars = _args[2:end] - dict_interior_indvars = Dict([indvar .=> j - for (j, indvar) in enumerate(dict_depvar_input[depvar])]) - dim_l = length(dict_interior_indvars) - - var_ = is_integral ? :(derivative) : :($(Expr(:$, :derivative))) - εs = [get_ε(dim_l, d, eltypeθ, order) for d in 1:dim_l] - undv = [dict_interior_indvars[d_p] for d_p in derivative_variables] - εs_dnv = [εs[d] for d in undv] - - ex.args = if !multioutput - [var_, :phi, :u, Symbol(:cord, num_depvar), εs_dnv, order, :($θ)] - else - [ - var_, - Symbol(:phi, num_depvar), - :u, - Symbol(:cord, num_depvar), - εs_dnv, - order, - Symbol(:($θ), num_depvar), - ] - end - break - elseif e isa Symbolics.Integral - if _args[1].domain.variables isa Tuple - integrating_variable_ = collect(_args[1].domain.variables) - integrating_variable = toexpr.(integrating_variable_) - integrating_var_id = [dict_indvars[i] for i in integrating_variable] - else - integrating_variable = toexpr(_args[1].domain.variables) - integrating_var_id = [dict_indvars[integrating_variable]] - end - - integrating_depvars = [] - integrand_expr = _args[2] - for d in depvars - d_ex = find_thing_in_expr(integrand_expr, d) - if !isempty(d_ex) - push!(integrating_depvars, d_ex[1].args[1]) - end - end - - lb, ub = get_limits(_args[1].domain.domain) - lb, ub, _args[2], dict_transformation_vars, transformation_vars = transform_inf_integral(lb, - ub, - _args[2], - integrating_depvars, - dict_depvar_input, - dict_depvars, - integrating_variable, - eltypeθ) - - num_depvar = map(int_depvar -> dict_depvars[int_depvar], - integrating_depvars) - integrand_ = transform_expression(pinnrep, _args[2]; - is_integral = false, - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars) - integrand__ = _dot_(integrand_) - - integrand = build_symbolic_loss_function(pinnrep, nothing; - integrand = integrand__, - integrating_depvars = integrating_depvars, - eq_params = SciMLBase.NullParameters(), - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars, - param_estim = false, - default_p = nothing) - # integrand = repr(integrand) - lb = toexpr.(lb) - ub = toexpr.(ub) - ub_ = [] - lb_ = [] - for l in lb - if l isa Number - push!(lb_, l) - else - l_expr = NeuralPDE.build_symbolic_loss_function(pinnrep, nothing; - integrand = _dot_(l), - integrating_depvars = integrating_depvars, - param_estim = false, - default_p = nothing) - l_f = @RuntimeGeneratedFunction(l_expr) - push!(lb_, l_f) - end - end - for u_ in ub - if u_ isa Number - push!(ub_, u_) - else - u_expr = NeuralPDE.build_symbolic_loss_function(pinnrep, nothing; - integrand = _dot_(u_), - integrating_depvars = integrating_depvars, - param_estim = false, - default_p = nothing) - u_f = @RuntimeGeneratedFunction(u_expr) - push!(ub_, u_f) - end - end - - integrand_func = @RuntimeGeneratedFunction(integrand) - ex.args = [ - :($(Expr(:$, :integral))), - :u, - Symbol(:cord, num_depvar[1]), - :phi, - integrating_var_id, - integrand_func, - lb_, - ub_, - 
:($θ), - ] - break - end - else - ex.args[i] = _transform_expression(pinnrep, ex.args[i]; - is_integral = is_integral, - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars) - end - end - return ex -end """ Parse ModelingToolkit equation form to the inner representation. From 70f1ce3638ca52a07b88fd8fa8ed35e5e779ca70 Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 8 Jun 2023 13:55:35 +0100 Subject: [PATCH 17/40] reinstate dot, closer --- src/discretize.jl | 8 ++++++-- src/eq_data.jl | 1 - src/loss_function_generation.jl | 35 +++++++++++---------------------- src/pinn_types.jl | 5 +++-- src/symbolic_utilities.jl | 32 +++++++++++++++++++++++++++++- 5 files changed, 52 insertions(+), 29 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index 82981b0d51..6f76cc7f84 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -15,7 +15,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap) else dxs = fill(dx, length(domains)) end - + @show dxs spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)] dict_var_span = Dict([d.variables => infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]) @@ -73,6 +73,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) pde_args = get_argument(eqs, v) + @show pde_args pde_lower_bounds = map(pde_args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -85,6 +86,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr pde_bounds = [pde_lower_bounds, pde_upper_bounds] bound_vars = get_variables(bcs, v) + @show bound_vars bcs_lower_bounds = map(bound_vars) do bt map(b -> dict_lower_bound[b], bt) @@ -93,7 +95,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr map(b -> dict_upper_bound[b], bt) end bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds] - + @show bcs_bounds pde_bounds [pde_bounds, bcs_bounds] end # TODO: Get this to work with varmap @@ -268,6 +270,8 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, adaloss = NonAdaptiveLoss{eltypeθ}() end + eqs = map(eq -> eq.lhs, eqs) + bcs = map(bc -> bc.lhs, bcs) pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, param_estim, additional_loss, adaloss, v, logger, diff --git a/src/eq_data.jl b/src/eq_data.jl index 5108eecd08..d94e4345d5 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -58,7 +58,6 @@ end argument(eq, eqdata) = eqdata.argmap[eq] - function get_iv_argument(eqs, v::VariableMap) vars = map(eqs) do eq _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index e3d881575a..8678008d5b 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -49,11 +49,13 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; end function get_coords(cord) + num_numbers = 0 mapreduce(vcat, enumerate(eq_args)) do (i, x) if x isa Number + num_numbers += 1 fill(convert(eltypeθ, x), size(cord[[1], :])) else - cord[[i], :] + cord[[i-num_numbers], :] end end end @@ -97,39 +99,26 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa ex_vars = get_depvars(term, varmap.depvar_ops) ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) - dummyvars 
= @variables phi, u(..), θ, coord + dummyvars = @variables phi, u(..), θ_SYMBOL, coord + dummyvars = unwrap.(dummyvars) deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) ch = Prewalk(Chain(deriv_rules)) expr = ch(term) - # Broadcast - ex_ops = operations(expr) - ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) - op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] - dotch = Prewalk(Chain(op_rules)) - expr = dotch(expr) - sym_coords = DestructuredArgs(ivs) ps = DestructuredArgs(varmap.ps) - args = [coord, θ, phi, u, ps] + args = [coord, θ_SYMBOL, phi, u, ps] - ex = Func(args, [], expr) |> toexpr + ex = Func(args, [], expr) |> toexpr |> _dot_ @show ex f = @RuntimeGeneratedFunction ex return f end -function get_ε(dim::Int, der_num::Int, ::Type{eltypeθ}, order) where {eltypeθ} - epsilon = ^(eps(eltypeθ), one(eltypeθ) / (2 + order)) - ε = zeros(eltypeθ, dim) - ε[der_num] = epsilon - ε -end - function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) phi, u, θ, coord = dummyvars dvs = depvars(term, eqdata) @@ -139,7 +128,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative derivative(phi, u, coord, [get_ε(length(arguments(w)), - j, eltypeθ, d)], + j, eltypeθ, i) for i in 1:d], d, θ) for d in differential_order(term, x)] for (j, x) in enumerate(varmap.args[operation(w)])], init = []) @@ -147,17 +136,17 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative # Mixed derivatives mx = mapreduce(vcat, dvs, init = []) do w mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (j, x) - map(enumerate(varmap.args[operation(w)])) do (k, y) + mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (k, y) if isequal(x, y) (_) -> nothing else n = length(arguments(w)) - @rule $((Differential(x))((Differential(y))(w))) => + [@rule $((Differential(x))((Differential(y))(w))) => derivative(phi, (cord_, θ_, phi_) -> derivative(phi_, u, cord_, - [get_ϵ(n, k, eltypeθ, 2)], 1, θ_), - coord, [get_ε(n, j, eltypeθ, 2)], 1, θ) + [get_ϵ(n, k, eltypeθ, i) for i in 1:2], 1, θ_), + coord, [get_ε(n, j, eltypeθ, 2)], 1, θ)] end end end diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 0c9d20f20f..2cedf3a82a 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -347,13 +347,14 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ) end function get_u() - u = (cord, θ, phi) -> phi(cord, θ) + u = (cord, θ, phi) -> phi.(cord, (θ,)) end # the method to calculate the derivative -function numeric_derivative(phi, u, x, ε, order, θ) +function numeric_derivative(phi, u, x, εs, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) + ε = εs[order] _epsilon = inv(first(ε[ε.!=zero(ε)])) ε = adapt(_type, ε) diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index b9fed4c216..16dc66f6d8 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -11,7 +11,6 @@ end θ = gensym("θ") - """ Override `Broadcast.__dot__` with `Broadcast.dottable(x::Function) = true` @@ -31,6 +30,36 @@ julia> _dot_(e) dottable_(x) = Broadcast.dottable(x) dottable_(x::Function) = true +_dot_(x) = x +function _dot_(x::Expr) + dotargs = Base.mapany(_dot_, x.args) + if x.head === :call && dottable_(x.args[1]) + Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) + elseif x.head === :comparison + Expr(:comparison, + (iseven(i) && dottable_(arg) && arg isa Symbol && isoperator(arg) ? + Symbol('.', arg) : arg for (i, arg) in pairs(dotargs))...) 
+ elseif x.head === :$ + x.args[1] + elseif x.head === :let # don't add dots to `let x=...` assignments + Expr(:let, undot(dotargs[1]), dotargs[2]) + elseif x.head === :for # don't add dots to for x=... assignments + Expr(:for, undot(dotargs[1]), dotargs[2]) + elseif (x.head === :(=) || x.head === :function || x.head === :macro) && + Meta.isexpr(x.args[1], :call) # function or macro definition + Expr(x.head, x.args[1], dotargs[2]) + elseif x.head === :(<:) || x.head === :(>:) + tmp = x.head === :(<:) ? :.<: : :.>: + Expr(:call, tmp, dotargs...) + else + head = String(x.head)::String + if last(head) == '=' && first(head) != '.' || head == "&&" || head == "||" + Expr(Symbol('.', head), dotargs...) + else + Expr(x.head, dotargs...) + end + end +end """ Create dictionary: variable => unique number for variable @@ -182,6 +211,7 @@ function get_argument(eqs, v::VariableMap) f_vars = filter(x -> !isempty(x), _vars) map(first, f_vars) end + @show vars args_ = map(vars) do _vars seen = [] filter(reduce(vcat, arguments.(_vars), init = [])) do x From b3967aa869be7628b38b8b9e70767ff5c0797691 Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 15 Jun 2023 16:39:58 +0100 Subject: [PATCH 18/40] last confusing errors? --- src/eq_data.jl | 1 + src/loss_function_generation.jl | 46 ++++--- src/pinn_types.jl | 15 ++- src/symbolic_utilities.jl | 10 +- test/NNPDE_tests.jl | 228 ++++++++++++++++---------------- 5 files changed, 161 insertions(+), 139 deletions(-) diff --git a/src/eq_data.jl b/src/eq_data.jl index d94e4345d5..b7f304d673 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -62,6 +62,7 @@ function get_iv_argument(eqs, v::VariableMap) vars = map(eqs) do eq _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) + @show v.ū mapreduce(vars -> mapreduce(op -> v.args[op], vcat, operation.(vars), init = []), vcat, f_vars, init = []) end args_ = map(vars) do _vars diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 8678008d5b..0be31500e3 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -50,18 +50,27 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; function get_coords(cord) num_numbers = 0 - mapreduce(vcat, enumerate(eq_args)) do (i, x) + out = map(enumerate(eq_args)) do (i, x) if x isa Number num_numbers += 1 - fill(convert(eltypeθ, x), size(cord[[1], :])) + fill(convert(eltypeθ, x), length(cord[[1], :])) else cord[[i-num_numbers], :] end end + if out === nothing + return [] + else + return out + end end - full_loss_func = (cord, θ, phi, u, p) -> begin - loss_function(get_coords(cord), θ, phi, u, get_ps(θ)) + full_loss_func = (cord, θ, phi, p) -> begin + coords = get_coords(cord) + @show coords + combinedcoords = reduce(vcat, coords, init = []) + @show combinedcoords + loss_function(coords, combinedcoords, θ, phi, get_ps(θ)) end return full_loss_func end @@ -72,8 +81,7 @@ function build_loss_function(pinnrep, eqs) _loss_function = build_symbolic_loss_function(pinnrep, eqs, eq_params = eq_params, param_estim = param_estim) - u = get_u() - loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, u, + loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, default_p) end return loss_function end @@ -99,7 +107,7 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa ex_vars = get_depvars(term, varmap.depvar_ops) ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) - dummyvars = @variables phi, u(..), θ_SYMBOL, 
coord + dummyvars = @variables phi(..), θ_SYMBOL, coord dummyvars = unwrap.(dummyvars) deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) @@ -111,7 +119,7 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa ps = DestructuredArgs(varmap.ps) - args = [coord, θ_SYMBOL, phi, u, ps] + args = [sym_coords, coord, θ_SYMBOL, phi, ps] ex = Func(args, [], expr) |> toexpr |> _dot_ @show ex @@ -120,14 +128,15 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa end function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) - phi, u, θ, coord = dummyvars - dvs = depvars(term, eqdata) + phi, θ, coord = dummyvars + dvs = get_depvars(term, varmap.depvar_ops) @show dvs # Orthodox derivatives + n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => derivative(phi, - u, coord, - [get_ε(length(arguments(w)), + ufunc, coord, + [get_ε(n(w), j, eltypeθ, i) for i in 1:d], d, θ) for d in differential_order(term, x)] @@ -138,21 +147,22 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (j, x) mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (k, y) if isequal(x, y) - (_) -> nothing + [(_) -> nothing] else - n = length(arguments(w)) + ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] + ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] [@rule $((Differential(x))((Differential(y))(w))) => derivative(phi, (cord_, θ_, phi_) -> - derivative(phi_, u, cord_, - [get_ϵ(n, k, eltypeθ, i) for i in 1:2], 1, θ_), - coord, [get_ε(n, j, eltypeθ, 2)], 1, θ)] + derivative(phi_, ufunc, cord_, + ε2, 1, θ_), + coord, ε1, 1, θ)] end end end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => u(coord, θ, phi) + @rule w => ufunc(coord, θ, phi) end return [mx; rs; vr] diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 2cedf3a82a..281e562f58 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -346,9 +346,9 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ) f.f(θ)(adapt(parameterless_type(θ), x)) end -function get_u() - u = (cord, θ, phi) -> phi.(cord, (θ,)) -end +ufunc(cord, θ, phi) = phi(cord, θ) +@register_symbolic ufunc(cord, θ, phi) + # the method to calculate the derivative function numeric_derivative(phi, u, x, εs, order, θ) @@ -365,27 +365,32 @@ function numeric_derivative(phi, u, x, εs, order, θ) # if order 1, this is trivially true if order > 4 || any(x -> x != εs[1], εs) + @show "me" return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) .- numeric_derivative(phi, u, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* _epsilon ./ 2 elseif order == 4 + @show "me4" return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi) .+ 6 .* u(x, θ, phi) .- 4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4 elseif order == 3 + @show "me3" return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi) - u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2 elseif order == 2 + @show "me2" return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2 elseif order == 1 + @show "me1" return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2 else error("This shouldn't happen! 
Got an order of $(order).") end end - -@register_symbolic numeric_derivative(phi, u, coord, εs, order, θ) +# Hacky workaround for metaprogramming with symbolics +@register_symbolic(numeric_derivative(phi, u, x, εs, order, θ), true, [], true) diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 16dc66f6d8..15cbb5e5ca 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -29,11 +29,15 @@ julia> _dot_(e) """ dottable_(x) = Broadcast.dottable(x) dottable_(x::Function) = true +dottable_(x::typeof(numeric_derivative)) = false +dottable_(x::Phi) = false + _dot_(x) = x function _dot_(x::Expr) dotargs = Base.mapany(_dot_, x.args) - if x.head === :call && dottable_(x.args[1]) + nodot = [:phi, Symbol("NeuralPDE.numeric_derivative")] + if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] !== s, nodot) Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) elseif x.head === :comparison Expr(:comparison, @@ -45,7 +49,9 @@ function _dot_(x::Expr) Expr(:let, undot(dotargs[1]), dotargs[2]) elseif x.head === :for # don't add dots to for x=... assignments Expr(:for, undot(dotargs[1]), dotargs[2]) - elseif (x.head === :(=) || x.head === :function || x.head === :macro) && + elseif x.head === :(=) # don't add dots to x=... assignments + Expr(:(=), dotargs[1], dotargs[2]) + elseif (x.head === :function || x.head === :macro) && Meta.isexpr(x.args[1], :call) # function or macro definition Expr(x.head, x.args[1], dotargs[2]) elseif x.head === :(<:) || x.head === :(>:) diff --git a/test/NNPDE_tests.jl b/test/NNPDE_tests.jl index 7303ebaa3e..40a47bae8d 100644 --- a/test/NNPDE_tests.jl +++ b/test/NNPDE_tests.jl @@ -435,117 +435,117 @@ end # plot(p1,p2) end -## Example 5, 2d wave equation, neumann boundary condition -@testset "Example 5, 2d wave equation, neumann boundary condition" begin - #here we use low level api for build solution - @parameters x, t - @variables u(..) 
- Dxx = Differential(x)^2 - Dtt = Differential(t)^2 - Dt = Differential(t) - - #2D PDE - C = 1 - eq = Dtt(u(x, t)) ~ C^2 * Dxx(u(x, t)) - - # Initial and boundary conditions - bcs = [u(0, t) ~ 0.0,# for all t > 0 - u(1, t) ~ 0.0,# for all t > 0 - u(x, 0) ~ x * (1.0 - x), #for all 0 < x < 1 - Dt(u(x, 0)) ~ 0.0] #for all 0 < x < 1] - - # Space and time domains - domains = [x ∈ Interval(0.0, 1.0), - t ∈ Interval(0.0, 1.0)] - @named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)]) - - # Neural network - chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ), Lux.Dense(16, 16, Lux.σ), Lux.Dense(16, 1)) - phi = NeuralPDE.Phi(chain) - derivative = NeuralPDE.numeric_derivative - - quadrature_strategy = NeuralPDE.QuadratureTraining(quadrature_alg = CubatureJLh(), - reltol = 1e-3, abstol = 1e-3, - maxiters = 50, batch = 100) - - discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) - prob = NeuralPDE.discretize(pde_system, discretization) - - cb_ = function (p, l) - println("loss: ", l) - println("losses: ", map(l -> l(p), loss_functions)) - return false - end - - res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 500, f_abstol = 10^-6) - - dx = 0.1 - xs, ts = [infimum(d.domain):dx:supremum(d.domain) for d in domains] - function analytic_sol_func(x, t) - sum([(8 / (k^3 * pi^3)) * sin(k * pi * x) * cos(C * k * pi * t) for k in 1:2:50000]) - end - - u_predict = reshape([first(phi([x, t], res.u)) for x in xs for t in ts], - (length(xs), length(ts))) - u_real = reshape([analytic_sol_func(x, t) for x in xs for t in ts], - (length(xs), length(ts))) - - @test u_predict≈u_real atol=0.1 - - # diff_u = abs.(u_predict .- u_real) - # p1 = plot(xs, ts, u_real, linetype=:contourf,title = "analytic"); - # p2 =plot(xs, ts, u_predict, linetype=:contourf,title = "predict"); - # p3 = plot(xs, ts, diff_u,linetype=:contourf,title = "error"); - # plot(p1,p2,p3) -end -## Example 6, pde with mixed derivative -@testset "Example 6, pde with mixed derivative" begin - @parameters x y - @variables u(..) 
- Dxx = Differential(x)^2 - Dyy = Differential(y)^2 - Dx = Differential(x) - Dy = Differential(y) - - eq = Dxx(u(x, y)) + Dx(Dy(u(x, y))) - 2 * Dyy(u(x, y)) ~ -1.0 - - # Initial and boundary conditions - bcs = [u(x, 0) ~ x, - Dy(u(x, 0)) ~ x, - u(x, 0) ~ Dy(u(x, 0))] - - # Space and time domains - domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] - - quadrature_strategy = NeuralPDE.QuadratureTraining() - # Neural network - inner = 20 - chain = Lux.Chain(Lux.Dense(2, inner, Lux.tanh), Lux.Dense(inner, inner, Lux.tanh), - Lux.Dense(inner, 1)) - - discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) - @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) - - prob = NeuralPDE.discretize(pde_system, discretization) - - res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 1500) - @show res.original - - phi = discretization.phi - - analytic_sol_func(x, y) = x + x * y + y^2 / 2 - xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains] - - u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys], - (length(xs), length(ys))) - u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys], - (length(xs), length(ys))) - diff_u = abs.(u_predict .- u_real) - - @test u_predict≈u_real rtol=0.1 - - # p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic"); - # p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict"); - # p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error"); - # plot(p1,p2,p3) -end +# ## Example 5, 2d wave equation, neumann boundary condition +# @testset "Example 5, 2d wave equation, neumann boundary condition" begin +# #here we use low level api for build solution +# @parameters x, t +# @variables u(..) +# Dxx = Differential(x)^2 +# Dtt = Differential(t)^2 +# Dt = Differential(t) + +# #2D PDE +# C = 1 +# eq = Dtt(u(x, t)) ~ C^2 * Dxx(u(x, t)) + +# # Initial and boundary conditions +# bcs = [u(0, t) ~ 0.0,# for all t > 0 +# u(1, t) ~ 0.0,# for all t > 0 +# u(x, 0) ~ x * (1.0 - x), #for all 0 < x < 1 +# Dt(u(x, 0)) ~ 0.0] #for all 0 < x < 1] + +# # Space and time domains +# domains = [x ∈ Interval(0.0, 1.0), +# t ∈ Interval(0.0, 1.0)] +# @named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)]) + +# # Neural network +# chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ), Lux.Dense(16, 16, Lux.σ), Lux.Dense(16, 1)) +# phi = NeuralPDE.Phi(chain) +# derivative = NeuralPDE.numeric_derivative + +# quadrature_strategy = NeuralPDE.QuadratureTraining(quadrature_alg = CubatureJLh(), +# reltol = 1e-3, abstol = 1e-3, +# maxiters = 50, batch = 100) + +# discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) +# prob = NeuralPDE.discretize(pde_system, discretization) + +# cb_ = function (p, l) +# println("loss: ", l) +# println("losses: ", map(l -> l(p), loss_functions)) +# return false +# end + +# res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 500, f_abstol = 10^-6) + +# dx = 0.1 +# xs, ts = [infimum(d.domain):dx:supremum(d.domain) for d in domains] +# function analytic_sol_func(x, t) +# sum([(8 / (k^3 * pi^3)) * sin(k * pi * x) * cos(C * k * pi * t) for k in 1:2:50000]) +# end + +# u_predict = reshape([first(phi([x, t], res.u)) for x in xs for t in ts], +# (length(xs), length(ts))) +# u_real = reshape([analytic_sol_func(x, t) for x in xs for t in ts], +# (length(xs), length(ts))) + +# @test u_predict≈u_real atol=0.1 + +# # diff_u = abs.(u_predict .- u_real) +# # p1 = plot(xs, ts, u_real, linetype=:contourf,title = "analytic"); +# # p2 =plot(xs, ts, u_predict, 
linetype=:contourf,title = "predict"); +# # p3 = plot(xs, ts, diff_u,linetype=:contourf,title = "error"); +# # plot(p1,p2,p3) +# end +# ## Example 6, pde with mixed derivative +# @testset "Example 6, pde with mixed derivative" begin +# @parameters x y +# @variables u(..) +# Dxx = Differential(x)^2 +# Dyy = Differential(y)^2 +# Dx = Differential(x) +# Dy = Differential(y) + +# eq = Dxx(u(x, y)) + Dx(Dy(u(x, y))) - 2 * Dyy(u(x, y)) ~ -1.0 + +# # Initial and boundary conditions +# bcs = [u(x, 0) ~ x, +# Dy(u(x, 0)) ~ x, +# u(x, 0) ~ Dy(u(x, 0))] + +# # Space and time domains +# domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] + +# quadrature_strategy = NeuralPDE.QuadratureTraining() +# # Neural network +# inner = 20 +# chain = Lux.Chain(Lux.Dense(2, inner, Lux.tanh), Lux.Dense(inner, inner, Lux.tanh), +# Lux.Dense(inner, 1)) + +# discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) +# @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) + +# prob = NeuralPDE.discretize(pde_system, discretization) + +# res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 1500) +# @show res.original + +# phi = discretization.phi + +# analytic_sol_func(x, y) = x + x * y + y^2 / 2 +# xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains] + +# u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys], +# (length(xs), length(ys))) +# u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys], +# (length(xs), length(ys))) +# diff_u = abs.(u_predict .- u_real) + +# @test u_predict≈u_real rtol=0.1 + +# # p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic"); +# # p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict"); +# # p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error"); +# # plot(p1,p2,p3) +# end From 17d01503de887b323f860e8c66281be93d00ca30 Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 16 Jun 2023 14:57:29 +0100 Subject: [PATCH 19/40] fix test --- src/loss_function_generation.jl | 2 +- test/NNPDE_tests.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 0be31500e3..ce7f7aa6db 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -37,6 +37,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; (!(phi isa Vector) && phi.f isa Optimisers.Restructure) if psform + @show length(phi) last_indx = [0; accumulate(+, map(length, init_params))][end] ps_range = 1:param_len .+ last_indx get_ps = (θ) -> θ[ps_range] @@ -122,7 +123,6 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa args = [sym_coords, coord, θ_SYMBOL, phi, ps] ex = Func(args, [], expr) |> toexpr |> _dot_ - @show ex f = @RuntimeGeneratedFunction ex return f end diff --git a/test/NNPDE_tests.jl b/test/NNPDE_tests.jl index 40a47bae8d..4d6a31c27b 100644 --- a/test/NNPDE_tests.jl +++ b/test/NNPDE_tests.jl @@ -331,7 +331,7 @@ end eq = Dx(Dxxu(x)) ~ cos(pi * x) # Initial and boundary conditions - bcs_ = [u(0.0) ~ 0.0, + bcs = [u(0.0) ~ 0.0, u(1.0) ~ cos(pi), Dxu(1.0) ~ 1.0] ep = (cbrt(eps(eltype(Float64))))^2 / 6 @@ -339,7 +339,7 @@ end der = [Dxu(x) ~ Dx(u(x)) + ep * O1(x), Dxxu(x) ~ Dx(Dxu(x)) + ep * O2(x)] - bcs = [bcs_; der] + eqs = [eq; der] # Space and time domains domains = [x ∈ Interval(0.0, 1.0)] @@ -352,7 +352,7 @@ end discretization = NeuralPDE.PhysicsInformedNN(chain, quasirandom_strategy) - @named pde_system = PDESystem(eq, bcs, domains, [x], + @named 
pde_system = PDESystem(eqs, bcs, domains, [x], [u(x), Dxu(x), Dxxu(x), O1(x), O2(x)]) prob = NeuralPDE.discretize(pde_system, discretization) From 0011392d4446eb7c71eea075ce179d2bc4e81a63 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Mon, 19 Jun 2023 13:43:25 +0100 Subject: [PATCH 20/40] add multioutput --- src/discretize.jl | 19 ++++++++++++++++++- src/loss_function_generation.jl | 19 ++++++++++--------- src/pinn_types.jl | 29 ++++++++++++++++------------- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index 6f76cc7f84..0b5cbc7260 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -264,6 +264,23 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, phi.st) end + if multioutput + dvs = v.ū + acum = [0; accumulate(+, map(length, init_params))] + sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] + phimap = map(enumerate(dvs)) do (i, dv) + if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f isa Optimisers.Restructure) + # Flux.Chain + dv => (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]]) + else # Lux.AbstractExplicitLayer + dv => (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv)) + end + end |> Dict + else + phimap = nothing + end + eltypeθ = eltype(flat_init_params) if adaloss === nothing @@ -276,7 +293,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, param_estim, additional_loss, adaloss, v, logger, multioutput, iteration, init_params, flat_init_params, phi, - derivative, + phimap, derivative, strategy, eqdata, nothing, nothing, nothing, nothing) #integral = get_numeric_integral(pinnrep) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index ce7f7aa6db..928be8049c 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -8,7 +8,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; integrand = nothing, transformation_vars = nothing) @unpack varmap, eqdata, - phi, derivative, integral, + phi, phimap, derivative, integral, multioutput, init_params, strategy, eq_params, param_estim, default_p = pinnrep @@ -77,7 +77,11 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; end function build_loss_function(pinnrep, eqs) - @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep + @unpack eq_params, param_estim, default_p, phi, phimap, multioutput, derivative, integral = pinnrep + + if multioutput + phi = phimap + end _loss_function = build_symbolic_loss_function(pinnrep, eqs, eq_params = eq_params, @@ -134,8 +138,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => - derivative(phi, - ufunc, coord, + derivative(ufunc(w, coord, θ, phi), coord, [get_ε(n(w), j, eltypeθ, i) for i in 1:d], d, θ) @@ -152,17 +155,15 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] [@rule $((Differential(x))((Differential(y))(w))) => - derivative(phi, - (cord_, θ_, phi_) -> - derivative(phi_, ufunc, cord_, - ε2, 1, θ_), + derivative((cord_, θ_) -> derivative(ufunc(w, coord, θ, phi), cord_, + ε2, 1, θ_), coord, ε1, 1, θ)] end end end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => ufunc(coord, θ, 
phi) + @rule w => ufunc(w, coord, θ, phi)(coord, θ) end return [mx; rs; vr] diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 281e562f58..0fae251df3 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -239,6 +239,10 @@ mutable struct PINNRepresentation The representation of the test function of the PDE solution """ phi::Any + """ + the map of vars to chains + """ + phimap::Any """ The function used for computing the derivative """ @@ -346,12 +350,11 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ) f.f(θ)(adapt(parameterless_type(θ), x)) end -ufunc(cord, θ, phi) = phi(cord, θ) -@register_symbolic ufunc(cord, θ, phi) +ufunc(u, cord, θ, phi) = phi isa Dict ? phi[u](cord, θ) : phi(cord, θ) # the method to calculate the derivative -function numeric_derivative(phi, u, x, εs, order, θ) +function numeric_derivative(phi, x, εs, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) ε = εs[order] @@ -366,31 +369,31 @@ function numeric_derivative(phi, u, x, εs, order, θ) if order > 4 || any(x -> x != εs[1], εs) @show "me" - return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) + return (numeric_derivative(phi, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) .- - numeric_derivative(phi, u, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* + numeric_derivative(phi, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* _epsilon ./ 2 elseif order == 4 @show "me4" - return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi) + return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ - 6 .* u(x, θ, phi) + 6 .* phi(x, θ) .- - 4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4 + 4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4 elseif order == 3 @show "me3" - return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi) + return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) - - u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2 + phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2 elseif order == 2 @show "me2" - return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2 + return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2 elseif order == 1 @show "me1" - return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2 + return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2 else error("This shouldn't happen! 
Got an order of $(order).") end end # Hacky workaround for metaprogramming with symbolics -@register_symbolic(numeric_derivative(phi, u, x, εs, order, θ), true, [], true) +@register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, [], true) From 8c3ad765e5a9dc517ef5848a7a2680ef75b2eff0 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Mon, 19 Jun 2023 13:51:42 +0100 Subject: [PATCH 21/40] ignore ds_store --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 876d74a0bf..1a5ddd55f2 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ Manifest.toml */testlogs docs/build/* scratch -scratch/* \ No newline at end of file +scratch/* +.DS_Store From a949169659cb40b6ceb5dcd2669bdf5dbad37424 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Mon, 31 Jul 2023 15:37:40 +0100 Subject: [PATCH 22/40] change to x(0) --- src/NeuralPDE.jl | 4 +- src/discretize.jl | 40 ++++-- src/loss_function_generation.jl | 55 ++++---- src/pinn_types.jl | 59 +++++++-- src/symbolic_utilities.jl | 18 +-- test/NNPDE_tests.jl | 228 ++++++++++++++++---------------- 6 files changed, 226 insertions(+), 178 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 3e0d3fe1b5..9bbd316a3b 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -22,7 +22,7 @@ using ArrayInterface import Optim using DomainSets using Symbolics -using Symbolics: wrap, unwrap, arguments, operation +using Symbolics: wrap, unwrap, arguments, operation, symtype using SymbolicUtils using SymbolicUtils.Code using SymbolicUtils: Prewalk, Postwalk, Chain @@ -34,7 +34,7 @@ import Optimisers import UnPack: @unpack import RecursiveArrayTools import ChainRulesCore, Flux, Lux, ComponentArrays -import ChainRulesCore: @non_differentiable +import ChainRulesCore: @non_differentiable, @ignore_derivatives RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/discretize.jl b/src/discretize.jl index 0b5cbc7260..aa6fe2f526 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -15,7 +15,6 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap) else dxs = fill(dx, length(domains)) end - @show dxs spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)] dict_var_span = Dict([d.variables => infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]) @@ -69,11 +68,10 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining. 
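# Illustration (not from the patch): the central-difference stencils that
# numeric_derivative above implements, checked on a plain scalar function standing
# in for the network phi. Step sizes mirror get_ε's eps^(1/(2 + order)).
f(x) = sin(x)
x0 = 0.3
ε1 = eps(Float64)^(1 / 3)                            # order-1 step
ε2 = eps(Float64)^(1 / 4)                            # order-2 step
d1 = (f(x0 + ε1) - f(x0 - ε1)) / (2ε1)               # ≈ cos(0.3)
d2 = (f(x0 + ε2) + f(x0 - ε2) - 2f(x0)) / ε2^2       # ≈ -sin(0.3)
isapprox(d1, cos(x0); atol = 1e-8), isapprox(d2, -sin(x0); atol = 1e-6)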
""" function get_bounds end -function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::AbstractGridfreeStrategy) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining) dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) pde_args = get_argument(eqs, v) - @show pde_args pde_lower_bounds = map(pde_args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -86,7 +84,6 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr pde_bounds = [pde_lower_bounds, pde_upper_bounds] bound_vars = get_variables(bcs, v) - @show bound_vars bcs_lower_bounds = map(bound_vars) do bt map(b -> dict_lower_bound[b], bt) @@ -95,9 +92,32 @@ function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::Abstr map(b -> dict_upper_bound[b], bt) end bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds] - @show bcs_bounds pde_bounds [pde_bounds, bcs_bounds] end + +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy) + dx = 1 / strategy.points + dict_span = Dict([d.variables => [ + infimum(d.domain) + dx, + supremum(d.domain) - dx, + ] for d in domains]) + + # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] + pde_args = get_argument(eqs, v) + pde_bounds = map(pde_args) do pde_arg + bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) + bds = eltypeθ.(bds) + bds[1, :], bds[2, :] + end + + bound_args = get_argument(bcs, v) + bcs_bounds = map(bound_args) do bound_arg + bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg) + bds = eltypeθ.(bds) + bds[1, :], bds[2, :] + end + return pde_bounds, bcs_bounds +end # TODO: Get this to work with varmap function get_numeric_integral(pinnrep::PINNRepresentation) @unpack strategy, multioutput, derivative, varmap = pinnrep @@ -268,15 +288,15 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, dvs = v.ū acum = [0; accumulate(+, map(length, init_params))] sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] - phimap = map(enumerate(dvs)) do (i, dv) + phi = map(enumerate(dvs)) do (i, dv) if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || (!(phi isa Vector) && phi.f isa Optimisers.Restructure) # Flux.Chain - dv => (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]]) + (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]]) else # Lux.AbstractExplicitLayer - dv => (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv)) + (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv)) end - end |> Dict + end else phimap = nothing end @@ -293,7 +313,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, param_estim, additional_loss, adaloss, v, logger, multioutput, iteration, init_params, flat_init_params, phi, - phimap, derivative, + derivative, strategy, eqdata, nothing, nothing, nothing, nothing) #integral = get_numeric_integral(pinnrep) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 928be8049c..79e17bbb70 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -8,7 +8,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; integrand = nothing, transformation_vars = nothing) @unpack varmap, eqdata, - phi, phimap, derivative, integral, + phi, derivative, integral, multioutput, init_params, strategy, eq_params, param_estim, default_p = pinnrep @@ 
-53,10 +53,9 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; num_numbers = 0 out = map(enumerate(eq_args)) do (i, x) if x isa Number - num_numbers += 1 - fill(convert(eltypeθ, x), length(cord[[1], :])) + fill(convert(eltypeθ, x), size(cord[[1], :])) else - cord[[i-num_numbers], :] + cord[[i], :] end end if out === nothing @@ -67,21 +66,15 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; end full_loss_func = (cord, θ, phi, p) -> begin - coords = get_coords(cord) - @show coords - combinedcoords = reduce(vcat, coords, init = []) - @show combinedcoords - loss_function(coords, combinedcoords, θ, phi, get_ps(θ)) + coords = [[nothing]] + @ignore_derivatives coords = get_coords(cord) + loss_function(coords, θ, phi, get_ps(θ)) end return full_loss_func end function build_loss_function(pinnrep, eqs) - @unpack eq_params, param_estim, default_p, phi, phimap, multioutput, derivative, integral = pinnrep - - if multioutput - phi = phimap - end + @unpack eq_params, param_estim, default_p, phi, multioutput, derivative, integral = pinnrep _loss_function = build_symbolic_loss_function(pinnrep, eqs, eq_params = eq_params, @@ -106,15 +99,19 @@ end function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = false, dict_transformation_vars = nothing, transformation_vars = nothing) - @unpack varmap, eqdata, derivative, integral, flat_init_params = pinnrep + @unpack varmap, eqdata, derivative, integral, flat_init_params, multioutput = pinnrep eltypeθ = eltype(flat_init_params) ex_vars = get_depvars(term, varmap.depvar_ops) - ignore = vcat(operation.(ex_vars), getindex, Differential, Integral, ~) - dummyvars = @variables phi(..), θ_SYMBOL, coord + if multioutput + dummyvars = @variables phi[1:length(varmap.ū)](..), θ_SYMBOL + else + dummyvars = @variables phi(..), θ_SYMBOL + end + dummyvars = unwrap.(dummyvars) - deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) + deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) ch = Prewalk(Chain(deriv_rules)) @@ -124,21 +121,27 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa ps = DestructuredArgs(varmap.ps) - args = [sym_coords, coord, θ_SYMBOL, phi, ps] + args = [sym_coords, θ_SYMBOL, phi, ps] ex = Func(args, [], expr) |> toexpr |> _dot_ + + @show ex f = @RuntimeGeneratedFunction ex return f end -function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap) - phi, θ, coord = dummyvars +function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) + phi, θ = dummyvars + if symtype(phi) isa AbstractArray + phi = collect(phi) + end + dvs = get_depvars(term, varmap.depvar_ops) - @show dvs + @show eltypeθ # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => - derivative(ufunc(w, coord, θ, phi), coord, + derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), [get_ε(n(w), j, eltypeθ, i) for i in 1:d], d, θ) @@ -155,15 +158,15 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] [@rule $((Differential(x))((Differential(y))(w))) => - derivative((cord_, θ_) -> derivative(ufunc(w, coord, θ, phi), cord_, + derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), ε2, 1, θ_), - coord, 
ε1, 1, θ)] + reducevcat(arguments(w), eltypeθ), ε1, 1, θ)] end end end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => ufunc(w, coord, θ, phi)(coord, θ) + @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ) end return [mx; rs; vr] diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 0fae251df3..a4a03ddc5f 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -239,10 +239,6 @@ mutable struct PINNRepresentation The representation of the test function of the PDE solution """ phi::Any - """ - the map of vars to chains - """ - phimap::Any """ The function used for computing the derivative """ @@ -350,9 +346,6 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ) f.f(θ)(adapt(parameterless_type(θ), x)) end -ufunc(u, cord, θ, phi) = phi isa Dict ? phi[u](cord, θ) : phi(cord, θ) - - # the method to calculate the derivative function numeric_derivative(phi, x, εs, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) @@ -368,28 +361,23 @@ function numeric_derivative(phi, x, εs, order, θ) # if order 1, this is trivially true if order > 4 || any(x -> x != εs[1], εs) - @show "me" return (numeric_derivative(phi, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) .- numeric_derivative(phi, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* _epsilon ./ 2 elseif order == 4 - @show "me4" return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) .- 4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4 elseif order == 3 - @show "me3" return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) - phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2 elseif order == 2 - @show "me2" return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2 elseif order == 1 - @show "me1" return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2 else error("This shouldn't happen! Got an order of $(order).") @@ -397,3 +385,50 @@ function numeric_derivative(phi, x, εs, order, θ) end # Hacky workaround for metaprogramming with symbolics @register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, [], true) + +function ufunc(u, phi, v) + if symtype(phi) isa AbstractArray + return phi[findfirst(w -> isequal(operation(w), operation(u)), v.ū)] + else + return phi + end +end + +#= +_vcat(x::Number...) = vcat(x...) +_vcat(x::AbstractArray{<:Number}...) = vcat(x...) +function _vcat(x::Union{Number, AbstractArray{<:Number}}...) + example = first(Iterators.filter(e -> !(e isa Number), x)) + dims = (1, size(example)[2:end]...) + x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x) + _vcat(x...) +end +_vcat(x...) = vcat(x...) 
+https://github.com/SciML/NeuralPDE.jl/pull/627/files +=# + + + +function reducevcat(vector, eltypeθ) + if all(x -> x isa Number, vector) + return vector + else + z = findfirst(x -> !(x isa Number), vector) + return rvcat(vector, vector[z], eltypeθ) + end +end + +function rvcat(example, vector, eltypeθ) + isnothing(vector) && return [[nothing]] + return mapreduce(hcat, vector) do x + if x isa Number + out = typeof(example)(fill(convert(eltypeθ, x), size(example))) + out + else + out = x + out + end + end +end + +@register_symbolic(rvcat(vector, example, eltypeθ), true, [], true) \ No newline at end of file diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 15cbb5e5ca..fc81f3020a 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -1,16 +1,5 @@ using Base.Broadcast -function get_limits(domain) - if domain isa AbstractInterval - return [leftendpoint(domain)], [rightendpoint(domain)] - elseif domain isa ProductDomain - return collect(map(leftendpoint, DomainSets.components(domain))), - collect(map(rightendpoint, DomainSets.components(domain))) - end -end - -θ = gensym("θ") - """ Override `Broadcast.__dot__` with `Broadcast.dottable(x::Function) = true` @@ -36,8 +25,8 @@ dottable_(x::Phi) = false _dot_(x) = x function _dot_(x::Expr) dotargs = Base.mapany(_dot_, x.args) - nodot = [:phi, Symbol("NeuralPDE.numeric_derivative")] - if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] !== s, nodot) + nodot = [:phi, Symbol("NeuralPDE.numeric_derivative"), NeuralPDE.rvcat] + if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] != s, nodot) Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) elseif x.head === :comparison Expr(:comparison, @@ -217,7 +206,6 @@ function get_argument(eqs, v::VariableMap) f_vars = filter(x -> !isempty(x), _vars) map(first, f_vars) end - @show vars args_ = map(vars) do _vars seen = [] filter(reduce(vcat, arguments.(_vars), init = [])) do x @@ -252,3 +240,5 @@ function get_number(eqs, v::VariableMap) args = get_argument(eqs, v) return map(arg -> filter(x -> x isa Number, arg), args) end + +sym_op(u) = Symbol(operation(u)) \ No newline at end of file diff --git a/test/NNPDE_tests.jl b/test/NNPDE_tests.jl index 4d6a31c27b..b4fb8018d6 100644 --- a/test/NNPDE_tests.jl +++ b/test/NNPDE_tests.jl @@ -435,117 +435,117 @@ end # plot(p1,p2) end -# ## Example 5, 2d wave equation, neumann boundary condition -# @testset "Example 5, 2d wave equation, neumann boundary condition" begin -# #here we use low level api for build solution -# @parameters x, t -# @variables u(..) 
-# Dxx = Differential(x)^2 -# Dtt = Differential(t)^2 -# Dt = Differential(t) - -# #2D PDE -# C = 1 -# eq = Dtt(u(x, t)) ~ C^2 * Dxx(u(x, t)) - -# # Initial and boundary conditions -# bcs = [u(0, t) ~ 0.0,# for all t > 0 -# u(1, t) ~ 0.0,# for all t > 0 -# u(x, 0) ~ x * (1.0 - x), #for all 0 < x < 1 -# Dt(u(x, 0)) ~ 0.0] #for all 0 < x < 1] - -# # Space and time domains -# domains = [x ∈ Interval(0.0, 1.0), -# t ∈ Interval(0.0, 1.0)] -# @named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)]) - -# # Neural network -# chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ), Lux.Dense(16, 16, Lux.σ), Lux.Dense(16, 1)) -# phi = NeuralPDE.Phi(chain) -# derivative = NeuralPDE.numeric_derivative - -# quadrature_strategy = NeuralPDE.QuadratureTraining(quadrature_alg = CubatureJLh(), -# reltol = 1e-3, abstol = 1e-3, -# maxiters = 50, batch = 100) - -# discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) -# prob = NeuralPDE.discretize(pde_system, discretization) - -# cb_ = function (p, l) -# println("loss: ", l) -# println("losses: ", map(l -> l(p), loss_functions)) -# return false -# end - -# res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 500, f_abstol = 10^-6) - -# dx = 0.1 -# xs, ts = [infimum(d.domain):dx:supremum(d.domain) for d in domains] -# function analytic_sol_func(x, t) -# sum([(8 / (k^3 * pi^3)) * sin(k * pi * x) * cos(C * k * pi * t) for k in 1:2:50000]) -# end - -# u_predict = reshape([first(phi([x, t], res.u)) for x in xs for t in ts], -# (length(xs), length(ts))) -# u_real = reshape([analytic_sol_func(x, t) for x in xs for t in ts], -# (length(xs), length(ts))) - -# @test u_predict≈u_real atol=0.1 - -# # diff_u = abs.(u_predict .- u_real) -# # p1 = plot(xs, ts, u_real, linetype=:contourf,title = "analytic"); -# # p2 =plot(xs, ts, u_predict, linetype=:contourf,title = "predict"); -# # p3 = plot(xs, ts, diff_u,linetype=:contourf,title = "error"); -# # plot(p1,p2,p3) -# end -# ## Example 6, pde with mixed derivative -# @testset "Example 6, pde with mixed derivative" begin -# @parameters x y -# @variables u(..) 
-# Dxx = Differential(x)^2 -# Dyy = Differential(y)^2 -# Dx = Differential(x) -# Dy = Differential(y) - -# eq = Dxx(u(x, y)) + Dx(Dy(u(x, y))) - 2 * Dyy(u(x, y)) ~ -1.0 - -# # Initial and boundary conditions -# bcs = [u(x, 0) ~ x, -# Dy(u(x, 0)) ~ x, -# u(x, 0) ~ Dy(u(x, 0))] - -# # Space and time domains -# domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] - -# quadrature_strategy = NeuralPDE.QuadratureTraining() -# # Neural network -# inner = 20 -# chain = Lux.Chain(Lux.Dense(2, inner, Lux.tanh), Lux.Dense(inner, inner, Lux.tanh), -# Lux.Dense(inner, 1)) - -# discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) -# @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) - -# prob = NeuralPDE.discretize(pde_system, discretization) - -# res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 1500) -# @show res.original - -# phi = discretization.phi - -# analytic_sol_func(x, y) = x + x * y + y^2 / 2 -# xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains] - -# u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys], -# (length(xs), length(ys))) -# u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys], -# (length(xs), length(ys))) -# diff_u = abs.(u_predict .- u_real) - -# @test u_predict≈u_real rtol=0.1 - -# # p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic"); -# # p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict"); -# # p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error"); -# # plot(p1,p2,p3) -# end +## Example 5, 2d wave equation, neumann boundary condition +@testset "Example 5, 2d wave equation, neumann boundary condition" begin + #here we use low level api for build solution + @parameters x, t + @variables u(..) + Dxx = Differential(x)^2 + Dtt = Differential(t)^2 + Dt = Differential(t) + + #2D PDE + C = 1 + eq = Dtt(u(x, t)) ~ C^2 * Dxx(u(x, t)) + + # Initial and boundary conditions + bcs = [u(0, t) ~ 0.0,# for all t > 0 + u(1, t) ~ 0.0,# for all t > 0 + u(x, 0) ~ x * (1.0 - x), #for all 0 < x < 1 + Dt(u(x, 0)) ~ 0.0] #for all 0 < x < 1] + + # Space and time domains + domains = [x ∈ Interval(0.0, 1.0), + t ∈ Interval(0.0, 1.0)] + @named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)]) + + # Neural network + chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ), Lux.Dense(16, 16, Lux.σ), Lux.Dense(16, 1)) + phi = NeuralPDE.Phi(chain) + derivative = NeuralPDE.numeric_derivative + + quadrature_strategy = NeuralPDE.QuadratureTraining(quadrature_alg = CubatureJLh(), + reltol = 1e-3, abstol = 1e-3, + maxiters = 50, batch = 100) + + discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) + prob = NeuralPDE.discretize(pde_system, discretization) + + cb_ = function (p, l) + println("loss: ", l) + println("losses: ", map(l -> l(p), loss_functions)) + return false + end + + res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 500, f_abstol = 10^-6) + + dx = 0.1 + xs, ts = [infimum(d.domain):dx:supremum(d.domain) for d in domains] + function analytic_sol_func(x, t) + sum([(8 / (k^3 * pi^3)) * sin(k * pi * x) * cos(C * k * pi * t) for k in 1:2:50000]) + end + + u_predict = reshape([first(phi([x, t], res.u)) for x in xs for t in ts], + (length(xs), length(ts))) + u_real = reshape([analytic_sol_func(x, t) for x in xs for t in ts], + (length(xs), length(ts))) + + @test u_predict≈u_real atol=0.1 + + # diff_u = abs.(u_predict .- u_real) + # p1 = plot(xs, ts, u_real, linetype=:contourf,title = "analytic"); + # p2 =plot(xs, ts, u_predict, linetype=:contourf,title = 
"predict"); + # p3 = plot(xs, ts, diff_u,linetype=:contourf,title = "error"); + # plot(p1,p2,p3) +end +## Example 6, pde with mixed derivative +@testset "Example 6, pde with mixed derivative" begin + @parameters x y + @variables u(..) + Dxx = Differential(x)^2 + Dyy = Differential(y)^2 + Dx = Differential(x) + Dy = Differential(y) + + eq = Dxx(u(x, y)) + Dx(Dy(u(x, y))) - 2 * Dyy(u(x, y)) ~ -1.0 + + # Initial and boundary conditions + bcs = [u(x, 0) ~ x, + Dy(u(x, 0)) ~ x, + u(x, 0) ~ Dy(u(x, 0))] + + # Space and time domains + domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] + + quadrature_strategy = NeuralPDE.QuadratureTraining() + # Neural network + inner = 20 + chain = Lux.Chain(Lux.Dense(2, inner, Lux.tanh), Lux.Dense(inner, inner, Lux.tanh), + Lux.Dense(inner, 1)) + + discretization = NeuralPDE.PhysicsInformedNN(chain, quadrature_strategy) + @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) + + prob = NeuralPDE.discretize(pde_system, discretization) + + res = solve(prob, OptimizationOptimJL.BFGS(); maxiters = 1500) + @show res.original + + phi = discretization.phi + + analytic_sol_func(x, y) = x + x * y + y^2 / 2 + xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains] + + u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys], + (length(xs), length(ys))) + u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys], + (length(xs), length(ys))) + diff_u = abs.(u_predict .- u_real) + + @test u_predict≈u_real rtol=0.1 + + # p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic"); + # p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict"); + # p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error"); + # plot(p1,p2,p3) +end From 6df392a87ed6124ca5e9e90086ebff7211e86b80 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 12:54:59 +0000 Subject: [PATCH 23/40] fix deved package --- Project.toml | 3 ++- src/pinn_types.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3cd1e08ba3..20398a250a 100644 --- a/Project.toml +++ b/Project.toml @@ -64,7 +64,7 @@ Lux = "0.4, 0.5" MCMCChains = "6" ModelingToolkit = "8" MonteCarloMeasurements = "1" -Optim = ">= 1.7.8" +Optim = ">= 1.7.7" Optimisers = "0.2, 0.3" Optimization = "3" QuasiMonteCarlo = "0.3.2" @@ -72,6 +72,7 @@ RecursiveArrayTools = "2.31" Reexport = "1.0" RuntimeGeneratedFunctions = "0.5" SciMLBase = "1.91, 2" +PDEBase = "0.1.7" Statistics = "1" StochasticDiffEq = "6.13" SymbolicUtils = "1" diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 90ef0fa65f..c64ac14092 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -384,7 +384,7 @@ function numeric_derivative(phi, x, εs, order, θ) end end # Hacky workaround for metaprogramming with symbolics -@register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, [], true) +@register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, []) function ufunc(u, phi, v) if symtype(phi) isa AbstractArray From a7d61ec22694255ce6975fe4af50abb19207a539 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 12:55:55 +0000 Subject: [PATCH 24/40] ditto --- src/pinn_types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index c64ac14092..fa495f4093 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -431,4 +431,4 @@ function rvcat(example, vector, eltypeθ) end end -@register_symbolic(rvcat(vector, example, eltypeθ), true, [], true) \ No newline at end of file 
+@register_symbolic(rvcat(vector, example, eltypeθ), true, []) \ No newline at end of file From 1cd80764708bda21a3f5f980134bbf57087f10a1 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 15:26:36 +0000 Subject: [PATCH 25/40] last ditch fix symbolic error --- src/loss_function_generation.jl | 12 ++++++++---- src/pinn_types.jl | 8 ++++---- src/symbolic_utilities.jl | 1 + 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 79e17bbb70..25bb72c417 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -105,9 +105,9 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa ex_vars = get_depvars(term, varmap.depvar_ops) if multioutput - dummyvars = @variables phi[1:length(varmap.ū)](..), θ_SYMBOL + dummyvars = @variables phi[1:length(varmap.ū)](..), θ_SYMBOL, switch else - dummyvars = @variables phi(..), θ_SYMBOL + dummyvars = @variables phi(..), θ_SYMBOL, switch end dummyvars = unwrap.(dummyvars) @@ -131,13 +131,14 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa end function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) - phi, θ = dummyvars + phi, θ, switch = dummyvars if symtype(phi) isa AbstractArray phi = collect(phi) end dvs = get_depvars(term, varmap.depvar_ops) @show eltypeθ + @show methods(derivative) # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => @@ -169,7 +170,9 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ) end - return [mx; rs; vr] + sr = @rule switch => 1 + + return [mx; rs; vr; sr] end function generate_integral_rules(eq, eqdata, dummyvars) @@ -178,3 +181,4 @@ function generate_integral_rules(eq, eqdata, dummyvars) #! 
with rules without putting symbols through the solve end + diff --git a/src/pinn_types.jl b/src/pinn_types.jl index fa495f4093..38ca4f4b4d 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -384,7 +384,7 @@ function numeric_derivative(phi, x, εs, order, θ) end end # Hacky workaround for metaprogramming with symbolics -@register_symbolic(numeric_derivative(phi, x, εs, order, θ), true, []) +@register_symbolic(numeric_derivative(phi, x, εs, order, θ)) function ufunc(u, phi, v) if symtype(phi) isa AbstractArray @@ -409,7 +409,7 @@ https://github.com/SciML/NeuralPDE.jl/pull/627/files -function reducevcat(vector, eltypeθ) +function reducevcat(vector::Vector, eltypeθ) if all(x -> x isa Number, vector) return vector else @@ -418,7 +418,7 @@ function reducevcat(vector, eltypeθ) end end -function rvcat(example, vector, eltypeθ) +function rvcat(example, vector, eltypeθ, switch) isnothing(vector) && return [[nothing]] return mapreduce(hcat, vector) do x if x isa Number @@ -431,4 +431,4 @@ function rvcat(example, vector, eltypeθ) end end -@register_symbolic(rvcat(vector, example, eltypeθ), true, []) \ No newline at end of file +@register_symbolic(rvcat(vector, example, eltypeθ, switch)) \ No newline at end of file diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index fc81f3020a..9a430e6e2a 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -89,6 +89,7 @@ function get_ε(dim::Int, der_num::Int, ::Type{eltypeθ}, order) where {eltypeθ epsilon = ^(eps(eltypeθ), one(eltypeθ) / (2 + order)) ε = zeros(eltypeθ, dim) ε[der_num] = epsilon + @show typeof(ε) ε end From 12995659ee9351c953bebd50a9fcc4a592e6adc9 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 16:12:03 +0000 Subject: [PATCH 26/40] fix switch --- src/loss_function_generation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 25bb72c417..eeb7bf5362 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -120,7 +120,6 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa sym_coords = DestructuredArgs(ivs) ps = DestructuredArgs(varmap.ps) - args = [sym_coords, θ_SYMBOL, phi, ps] ex = Func(args, [], expr) |> toexpr |> _dot_ @@ -137,12 +136,13 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end dvs = get_depvars(term, varmap.depvar_ops) + ivs = get_indvars(term, v) @show eltypeθ @show methods(derivative) # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => - derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), + derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ, switch), [get_ε(n(w), j, eltypeθ, i) for i in 1:d], d, θ) @@ -159,7 +159,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] [@rule $((Differential(x))((Differential(y))(w))) => - derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), + derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ, switch), ε2, 1, θ_), reducevcat(arguments(w), eltypeθ), ε1, 1, θ)] end From aa123214d7283f0ca24d7213cf170c19fd10b3e7 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 16:47:34 +0000 Subject: [PATCH 27/40] undo stupidity --- 
src/loss_function_generation.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index eeb7bf5362..cb0d18e178 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -136,7 +136,6 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end dvs = get_depvars(term, varmap.depvar_ops) - ivs = get_indvars(term, v) @show eltypeθ @show methods(derivative) # Orthodox derivatives @@ -180,5 +179,4 @@ function generate_integral_rules(eq, eqdata, dummyvars) #! all that should be needed is to solve an integral problem, the trick is doing this #! with rules without putting symbols through the solve -end - +end \ No newline at end of file From cd0052e5d27b6b6e5ffed0da50afcd8014bb9aed Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 18:12:32 +0000 Subject: [PATCH 28/40] fix switch --- src/pinn_types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 38ca4f4b4d..92e6950daf 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -414,7 +414,7 @@ function reducevcat(vector::Vector, eltypeθ) return vector else z = findfirst(x -> !(x isa Number), vector) - return rvcat(vector, vector[z], eltypeθ) + return rvcat(vector, vector[z], eltypeθ, switch) end end From 2e57081cb7614db710f22d0a48dca13782be8760 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 19:01:40 +0000 Subject: [PATCH 29/40] " --- src/pinn_types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 92e6950daf..5dd983a12d 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -409,7 +409,7 @@ https://github.com/SciML/NeuralPDE.jl/pull/627/files -function reducevcat(vector::Vector, eltypeθ) +function reducevcat(vector::Vector, eltypeθ, switch) if all(x -> x isa Number, vector) return vector else From 39bd9c72459411f3ba22945227719cac8429dcb5 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Sun, 7 Jan 2024 00:19:23 +0000 Subject: [PATCH 30/40] rerun tests --- src/loss_function_generation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index cb0d18e178..a938fccfb9 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -166,7 +166,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ) + @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ, switch), θ) end sr = @rule switch => 1 From aafd11da5c68c5e55441a28da2475f5e15ea9a9f Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Sun, 7 Jan 2024 23:58:33 +0000 Subject: [PATCH 31/40] " --- src/loss_function_generation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index a938fccfb9..7a1753fc44 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -160,7 +160,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative [@rule $((Differential(x))((Differential(y))(w))) => derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ, switch), ε2, 1, θ_), - reducevcat(arguments(w), eltypeθ), ε1, 1, θ)] + reducevcat(arguments(w), eltypeθ, switch), ε1, 1, θ)] end end end From 
cf170915121f8c51103867ac71b0e01b56c89a50 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Mon, 8 Jan 2024 12:02:34 +0000 Subject: [PATCH 32/40] test --- src/loss_function_generation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 7a1753fc44..1ce500706a 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -163,7 +163,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative reducevcat(arguments(w), eltypeθ, switch), ε1, 1, θ)] end end - end + end end vr = mapreduce(vcat, dvs, init = []) do w @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ, switch), θ) From 765c047fd18e763fab4eefb5e28cea5bdee9c4f6 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Fri, 12 Jan 2024 16:04:58 +0000 Subject: [PATCH 33/40] try other method --- src/loss_function_generation.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 1ce500706a..96e3ef276a 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -111,11 +111,12 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa end dummyvars = unwrap.(dummyvars) - deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) + deriv_rules, swch = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) ch = Prewalk(Chain(deriv_rules)) expr = ch(term) + expr = swch(expr) sym_coords = DestructuredArgs(ivs) ps = DestructuredArgs(varmap.ps) @@ -170,8 +171,8 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end sr = @rule switch => 1 - - return [mx; rs; vr; sr] + swch = Postwalk(sr) + return [mx; rs; vr], swch end function generate_integral_rules(eq, eqdata, dummyvars) From 8406bed99ecec55ec9a01c1b2b6104a417c2bdce Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 18 Jan 2024 16:40:53 +0000 Subject: [PATCH 34/40] rvcat --- src/pinn_types.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 901ded980d..e2c6a37196 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -539,6 +539,7 @@ https://github.com/SciML/NeuralPDE.jl/pull/627/files function reducevcat(vector::Vector, eltypeθ, switch) + isnothing(vector) && return [[nothing]] if all(x -> x isa Number, vector) return vector else @@ -547,9 +548,8 @@ function reducevcat(vector::Vector, eltypeθ, switch) end end -function rvcat(example, vector, eltypeθ, switch) - isnothing(vector) && return [[nothing]] - return mapreduce(hcat, vector) do x +function rvcat(example, sym, eltypeθ, switch) + return mapreduce(hcat, example) do x if x isa Number out = typeof(example)(fill(convert(eltypeθ, x), size(example))) out From 60e2c8c3e076a358bc6ef0b4764a6f40837fbdc1 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 18 Jan 2024 16:57:28 +0000 Subject: [PATCH 35/40] try something --- src/pinn_types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 81e2e83a57..8c850c7463 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -558,7 +558,7 @@ end function rvcat(example, sym, eltypeθ, switch) return mapreduce(hcat, example) do x if x isa Number - out = typeof(example)(fill(convert(eltypeθ, x), size(example))) + out = convert(eltypeθ, x) out else out = x From 
492c2dfe3e5c458cc4da1b98a489aeb741dc3735 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 18 Jan 2024 17:06:14 +0000 Subject: [PATCH 36/40] test --- src/pinn_types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 8c850c7463..689702a378 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -556,7 +556,7 @@ function reducevcat(vector::Vector, eltypeθ, switch) end function rvcat(example, sym, eltypeθ, switch) - return mapreduce(hcat, example) do x + return map(example) do x if x isa Number out = convert(eltypeθ, x) out From fa988ad6880407fca0f45eea2769deba43581817 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 18 Jan 2024 17:29:16 +0000 Subject: [PATCH 37/40] remove show --- src/eq_data.jl | 1 - src/loss_function_generation.jl | 3 --- src/symbolic_utilities.jl | 1 - 3 files changed, 5 deletions(-) diff --git a/src/eq_data.jl b/src/eq_data.jl index b7f304d673..d94e4345d5 100644 --- a/src/eq_data.jl +++ b/src/eq_data.jl @@ -62,7 +62,6 @@ function get_iv_argument(eqs, v::VariableMap) vars = map(eqs) do eq _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) - @show v.ū mapreduce(vars -> mapreduce(op -> v.args[op], vcat, operation.(vars), init = []), vcat, f_vars, init = []) end args_ = map(vars) do _vars diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 96e3ef276a..211192e1b3 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -37,7 +37,6 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; (!(phi isa Vector) && phi.f isa Optimisers.Restructure) if psform - @show length(phi) last_indx = [0; accumulate(+, map(length, init_params))][end] ps_range = 1:param_len .+ last_indx get_ps = (θ) -> θ[ps_range] @@ -137,8 +136,6 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end dvs = get_depvars(term, varmap.depvar_ops) - @show eltypeθ - @show methods(derivative) # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index 9a430e6e2a..fc81f3020a 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -89,7 +89,6 @@ function get_ε(dim::Int, der_num::Int, ::Type{eltypeθ}, order) where {eltypeθ epsilon = ^(eps(eltypeθ), one(eltypeθ) / (2 + order)) ε = zeros(eltypeθ, dim) ε[der_num] = epsilon - @show typeof(ε) ε end From a49b1b7570ba5a9b62ef99f43921d8afa85ea44e Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Tue, 13 Feb 2024 15:42:56 +0000 Subject: [PATCH 38/40] arrayop wrap --- src/NeuralPDE.jl | 2 +- src/pinn_types.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 626b24d7f2..37e5e2871d 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -23,7 +23,7 @@ using ArrayInterface import Optim using DomainSets using Symbolics -using Symbolics: wrap, unwrap, arguments, operation, symtype +using Symbolics: wrap, unwrap, arguments, operation, symtype, @arrayop using SymbolicUtils using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains using MonteCarloMeasurements diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 532fe1f8aa..dcc9f332d6 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -522,6 +522,8 @@ function numeric_derivative(phi, x, εs, order, θ) ε = adapt(_type, ε) x = adapt(_type, x) + ε = @arrayop (i,) ε[i] for i in 
1:length(ε) + x = @arrayop (i,) x[i] for i in 1:length(x) # any(x->x!=εs[1],εs) # εs is the epsilon for each order, if they are all the same then we use a fancy formula From 5fa91f3542dfc6004a5e7ecaeedab1fc24f11c19 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Tue, 13 Feb 2024 16:05:11 +0000 Subject: [PATCH 39/40] fix --- Project.toml | 1 - src/NeuralPDE.jl | 1 - src/loss_function_generation.jl | 2 +- src/pinn_types.jl | 32 ++------------------------------ 4 files changed, 3 insertions(+), 33 deletions(-) diff --git a/Project.toml b/Project.toml index 6854fd2c68..5048be40c0 100644 --- a/Project.toml +++ b/Project.toml @@ -64,7 +64,6 @@ MCMCChains = "6" ModelingToolkit = "8" MonteCarloMeasurements = "1" Optim = "1.7.8" -Optimisers = "0.2, 0.3" Optimization = "3" OptimizationOptimisers = "0.1" QuasiMonteCarlo = "0.3.2" diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 37e5e2871d..f91ebc3fda 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -34,7 +34,6 @@ import DomainSets: Domain, ClosedInterval import ModelingToolkit: Interval, infimum, supremum #,Ball import SciMLBase: @add_kwonly, parameterless_type import UnPack: @unpack -import RecursiveArrayTools import ChainRulesCore, Lux, ComponentArrays import ChainRulesCore: @non_differentiable, @ignore_derivatives diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 211192e1b3..3e2b046c00 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -168,7 +168,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end sr = @rule switch => 1 - swch = Postwalk(sr) + swch = Postwalk(Chain(sr)) return [mx; rs; vr], swch end diff --git a/src/pinn_types.jl b/src/pinn_types.jl index dcc9f332d6..127748b839 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -350,23 +350,7 @@ mutable struct PINNRepresentation """ The dependent variables of the system """ - depvars::Any - """ - The independent variables of the system - """ - indvars::Any - """ - A dictionary form of the independent variables. Define the structure ??? - """ - dict_indvars::Any - """ - A dictionary form of the dependent variables. Define the structure ??? - """ - dict_depvars::Any - """ - ??? - """ - dict_depvar_input::Any + varmap::Any """ The logger as provided by the user """ @@ -411,19 +395,7 @@ mutable struct PINNRepresentation """ ??? """ - pde_indvars::Any - """ - ??? - """ - bc_indvars::Any - """ - ??? - """ - pde_integration_vars::Any - """ - ??? - """ - bc_integration_vars::Any + eqdata::Any """ ??? 
""" From 4f0c9b55b3cfe67fb888849be18259fd00515db4 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Tue, 13 Feb 2024 19:53:01 +0000 Subject: [PATCH 40/40] almost --- src/NeuralPDE.jl | 2 +- src/loss_function_generation.jl | 22 +++++++++----------- src/pinn_types.jl | 37 +++++++++++++-------------------- 3 files changed, 25 insertions(+), 36 deletions(-) diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 91dab82003..7d2d0ea922 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -24,7 +24,7 @@ using ArrayInterface import Optim using DomainSets using Symbolics -using Symbolics: wrap, unwrap, arguments, operation, symtype, @arrayop +using Symbolics: wrap, unwrap, arguments, operation, symtype, @arrayop, Arr using SymbolicUtils using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains using MonteCarloMeasurements diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 3e2b046c00..831800a6ba 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -67,6 +67,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; full_loss_func = (cord, θ, phi, p) -> begin coords = [[nothing]] @ignore_derivatives coords = get_coords(cord) + @show coords loss_function(coords, θ, phi, get_ps(θ)) end return full_loss_func @@ -110,12 +111,12 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa end dummyvars = unwrap.(dummyvars) - deriv_rules, swch = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) + deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput) ch = Prewalk(Chain(deriv_rules)) expr = ch(term) - expr = swch(expr) + #expr = swch(expr) sym_coords = DestructuredArgs(ivs) ps = DestructuredArgs(varmap.ps) @@ -139,9 +140,8 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => - derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ, switch), - [get_ε(n(w), - j, eltypeθ, i) for i in 1:d], + derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), + get_ε(n(w),j, eltypeθ, d), d, θ) for d in differential_order(term, x)] for (j, x) in enumerate(varmap.args[operation(w)])], init = []) @@ -153,10 +153,10 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative if isequal(x, y) [(_) -> nothing] else - ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] - ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] + ε1 = get_ε(n(w), j, eltypeθ, 1) + ε2 = get_ε(n(w), k, eltypeθ, 1) [@rule $((Differential(x))((Differential(y))(w))) => - derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ, switch), + derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), ε2, 1, θ_), reducevcat(arguments(w), eltypeθ, switch), ε1, 1, θ)] end @@ -164,12 +164,10 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end end vr = mapreduce(vcat, dvs, init = []) do w - @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ, switch), θ) + @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ) end - sr = @rule switch => 1 - swch = Postwalk(Chain(sr)) - return [mx; rs; vr], swch + return [mx; rs; vr] end function generate_integral_rules(eq, eqdata, dummyvars) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 
127748b839..68352e243f 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -476,6 +476,7 @@ function (f::Phi{<:Lux.AbstractExplicitLayer})(x::Number, θ) end function (f::Phi{<:Lux.AbstractExplicitLayer})(x::AbstractArray, θ) + @show x, typeof(x) y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st) ChainRulesCore.@ignore_derivatives f.st = st y @@ -486,27 +487,14 @@ function (f::Phi{<:Optimisers.Restructure})(x, θ) end # the method to calculate the derivative -function numeric_derivative(phi, x, εs, order, θ) +function numeric_derivative(phi, x, ε, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) - ε = εs[order] - _epsilon = inv(first(ε[ε.!=zero(ε)])) - + _epsilon = inv(first(ε[ε.!=zero(eltype(ε))])) ε = adapt(_type, ε) x = adapt(_type, x) - ε = @arrayop (i,) ε[i] for i in 1:length(ε) - x = @arrayop (i,) x[i] for i in 1:length(x) - - # any(x->x!=εs[1],εs) - # εs is the epsilon for each order, if they are all the same then we use a fancy formula - # if order 1, this is trivially true - if order > 4 || any(x -> x != εs[1], εs) - return (numeric_derivative(phi, x .+ ε, @view(εs[1:(end-1)]), order - 1, θ) - .- - numeric_derivative(phi, x .- ε, @view(εs[1:(end-1)]), order - 1, θ)) .* - _epsilon ./ 2 - elseif order == 4 + if order == 4 return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ 6 .* phi(x, θ) @@ -524,8 +512,8 @@ function numeric_derivative(phi, x, εs, order, θ) error("This shouldn't happen! Got an order of $(order).") end end -# Hacky workaround for metaprogramming with symbolics -@register_symbolic(numeric_derivative(phi, x, εs, order, θ)) + +#@register_symbolic(numeric_derivative(phi, x, ε, order, θ)) function ufunc(u, phi, v) if symtype(phi) isa AbstractArray @@ -550,18 +538,18 @@ https://github.com/SciML/NeuralPDE.jl/pull/627/files -function reducevcat(vector::Vector, eltypeθ, switch) +function reducevcat(vector::Vector, eltypeθ) isnothing(vector) && return [[nothing]] if all(x -> x isa Number, vector) return vector else z = findfirst(x -> !(x isa Number), vector) - return rvcat(vector, vector[z], eltypeθ, switch) + return rvcat(vector, vector[z], eltypeθ) end end -function rvcat(example, sym, eltypeθ, switch) - return map(example) do x +function rvcat(example, sym, eltypeθ) + out = map(example) do x if x isa Number out = convert(eltypeθ, x) out @@ -570,6 +558,9 @@ function rvcat(example, sym, eltypeθ, switch) out end end + #out = @arrayop (i,) out[i] i in 1:length(out) + + return out end -@register_symbolic(rvcat(vector, example, eltypeθ, switch)) \ No newline at end of file +#@register_symbolic(rvcat(vector, example, eltypeθ, switch)) \ No newline at end of file
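
Note on the derivative scheme in the final patches above: the reworked numeric_derivative evaluates the network phi at coordinates shifted by a perturbation vector ε and combines the results with central-difference weights, recovering the scalar step via _epsilon = inv(first(ε[ε .!= 0])) so the same ε can be broadcast across a whole batch of coordinates; the step itself is built in get_ε as eps(eltypeθ)^(1/(2 + order)). The short Julia sketch below reproduces that scheme outside the package with a plain test function standing in for phi; the names f, central_diff, and the chosen test point are illustrative assumptions only and are not NeuralPDE API.

# Minimal standalone sketch of the central-difference scheme used by numeric_derivative.
# f is a hypothetical stand-in for phi(x, θ); nothing here is NeuralPDE code.
f(x) = sin(x[1]) * x[2]^2

function central_diff(f, x, dim, order)
    # same step-size heuristic as get_ε in the patches
    epsilon = eps(Float64)^(1 / (2 + order))
    ε = zeros(length(x))
    ε[dim] = epsilon
    _epsilon = inv(epsilon)
    if order == 2
        return (f(x .+ ε) .+ f(x .- ε) .- 2 .* f(x)) .* _epsilon^2
    elseif order == 1
        return (f(x .+ ε) .- f(x .- ε)) .* _epsilon ./ 2
    else
        error("only orders 1 and 2 are sketched here")
    end
end

x0 = [0.3, 2.0]
central_diff(f, x0, 1, 1)   # ≈ cos(0.3) * 2.0^2  (first derivative in x[1])
central_diff(f, x0, 2, 2)   # ≈ 2 * sin(0.3)      (second derivative in x[2])

The fractional exponent on eps trades truncation error against floating-point cancellation for each derivative order, which is why the patches keep the order-dependent step in get_ε rather than a fixed constant.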