Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Parsing with Symbolics #678

Closed
wants to merge 49 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
043d429
begin symbolics overhaul
xtalax Apr 12, 2023
867ed37
include
xtalax Apr 12, 2023
9d52d56
further progress
xtalax Apr 13, 2023
5e27b18
start new loss
xtalax Apr 27, 2023
1947fb8
cardinalize eqs
xtalax May 16, 2023
27fc8bc
add todos
xtalax May 17, 2023
1fcd390
polish loss and refactor
xtalax May 25, 2023
b85eaa8
fix
xtalax May 25, 2023
8e39ede
"final" parsing updates
xtalax May 26, 2023
0f99811
oops
xtalax May 26, 2023
60aaa6b
Merge branch 'SciML:master' into parsing
xtalax May 30, 2023
3806caf
test fixes
xtalax May 30, 2023
c23df24
more fixes
xtalax May 30, 2023
5fdc7a5
move to varmap, deprecate old parsing
xtalax May 31, 2023
5891313
more fixes
xtalax Jun 5, 2023
18c3449
fix the tests a bit more
xtalax Jun 7, 2023
82191e7
remove transform expression
xtalax Jun 7, 2023
70f1ce3
reinstate dot, closer
xtalax Jun 8, 2023
b3967aa
last confusing errors?
xtalax Jun 15, 2023
17d0150
fix test
xtalax Jun 16, 2023
0011392
add multioutput
xtalax Jun 19, 2023
8c3ad76
ignore ds_store
xtalax Jun 19, 2023
a949169
change to x(0)
xtalax Jul 31, 2023
edb83d3
Merge branch 'parsing' into rel
xtalax Jan 3, 2024
8704b2e
Merge pull request #3 from SciML/rel
xtalax Jan 3, 2024
6df392a
fix deved package
xtalax Jan 4, 2024
a7d61ec
ditto
xtalax Jan 4, 2024
1cd8076
last ditch fix symbolic error
xtalax Jan 4, 2024
1299565
fix switch
xtalax Jan 4, 2024
aa12321
undo stupidity
xtalax Jan 4, 2024
cd0052e
fix switch
xtalax Jan 4, 2024
2e57081
"
xtalax Jan 4, 2024
39bd9c7
rerun tests
xtalax Jan 7, 2024
aafd11d
"
xtalax Jan 7, 2024
cf17091
test
xtalax Jan 8, 2024
4a0ee8a
Merge branch 'master' into parsing
xtalax Jan 8, 2024
765c047
try other method
xtalax Jan 12, 2024
8406bed
rvcat
xtalax Jan 18, 2024
82bd6bd
Merge branch 'master' into parsing
xtalax Jan 18, 2024
60e2c8c
try something
xtalax Jan 18, 2024
492c2df
test
xtalax Jan 18, 2024
fa988ad
remove show
xtalax Jan 18, 2024
da31f19
Merge branch 'parsing' into tmp
xtalax Feb 9, 2024
e2460f6
Merge pull request #4 from SciML/tmp
xtalax Feb 9, 2024
a49b1b7
arrayop wrap
xtalax Feb 13, 2024
5fa91f3
fix
xtalax Feb 13, 2024
89ad740
Merge branch 'parsing' into tmp1
xtalax Feb 13, 2024
143367c
Merge pull request #5 from SciML/tmp1
xtalax Feb 13, 2024
4f0c9b5
almost
xtalax Feb 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
PDEBase = "a7812802-0625-4b9e-961c-d332478797e5"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand Down Expand Up @@ -79,4 +80,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"]
test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"]
6 changes: 4 additions & 2 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ using Integrals, IntegralsCubature
using QuasiMonteCarlo
using RuntimeGeneratedFunctions
using SciMLBase
using PDEBase
using PDEBase: cardinalize_eqs!
using Statistics
using ArrayInterface
import Optim
Expand All @@ -32,13 +34,13 @@ import RecursiveArrayTools
import ChainRulesCore, Flux, Lux, ComponentArrays
import ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)

abstract type AbstractPINN end
RuntimeGeneratedFunctions.init(@__MODULE__)

abstract type AbstractTrainingStrategy end

include("pinn_types.jl")
include("eq_data.jl")
include("symbolic_utilities.jl")
include("training_strategies.jl")
include("adaptive_losses.jl")
Expand Down
157 changes: 28 additions & 129 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,31 @@ for Flux.Chain, and

for Lux.AbstractExplicitLayer
"""
function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq;
eq_params = SciMLBase.NullParameters(),
param_estim = false,
default_p = nothing,
bc_indvars = pinnrep.indvars,
bc_indvars = pinnrep.v.x̄,
integrand = nothing,
dict_transformation_vars = nothing,
transformation_vars = nothing,
integrating_depvars = pinnrep.depvars)
@unpack indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input,
integrating_depvars = pinnrep.v.ū)
@unpack v, eqdata,
phi, derivative, integral,
multioutput, init_params, strategy, eq_params,
param_estim, default_p = pinnrep

eltypeθ = eltype(pinnrep.flat_init_params)

if integrand isa Nothing
loss_function = parse_equation(pinnrep, eqs)
this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input)
this_eq_indvars = unique(vcat(values(this_eq_pair)...))
loss_function = parse_equation(pinnrep, eq)
this_eq_pair = pair(eq, depvars, dict_depvars, dict_depvar_input)
this_eq_indvars = indvars(eq, eqmap)
else
this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars],
integrating_depvars))
this_eq_indvars = transformation_vars isa Nothing ?
unique(vcat(values(this_eq_pair)...)) : transformation_vars
unique(indvars(eq, eqmap)) : transformation_vars
loss_function = integrand
end

Expand Down Expand Up @@ -287,12 +287,11 @@ function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Arra
return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
end

function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars,
strategy::QuadratureTraining)
dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains])
dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains])

pde_args = get_argument(eqs, dict_indvars, dict_depvars)
function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining)
dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains])
dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains])
#! Fix this to work with a var_eq mapping
pde_args = get_argument(eqs, v)

pde_lower_bounds = map(pde_args) do pd
span = map(p -> get(dict_lower_bound, p, p), pd)
Expand Down Expand Up @@ -342,10 +341,9 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str
end

function get_numeric_integral(pinnrep::PINNRepresentation)
@unpack strategy, indvars, depvars, multioutput, derivative,
depvars, indvars, dict_indvars, dict_depvars = pinnrep
@unpack strategy, multioutput, derivative, varmap = pinnrep

integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, indvars = indvars, depvars = depvars, dict_indvars = dict_indvars, dict_depvars = dict_depvars) -> begin
integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, varmap=varmap) -> begin
function integration_(cord, lb, ub, θ)
cord_ = cord
function integrand_(x, p)
Expand Down Expand Up @@ -402,6 +400,7 @@ For more information, see `discretize` and `PINNRepresentation`.
"""
function SciMLBase.symbolic_discretize(pde_system::PDESystem,
discretization::PhysicsInformedNN)
cardinalize_eqs!(pde_system)
eqs = pde_system.eqs
bcs = pde_system.bcs
chain = discretization.chain
Expand All @@ -416,92 +415,19 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
additional_loss = discretization.additional_loss
adaloss = discretization.adaptive_loss

depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(pde_system.indvars,
pde_system.depvars)
v = VariableMap(pde_system, discretization)

multioutput = discretization.multioutput
init_params = discretization.init_params
# Find the derivative orders in the bcs
bcorders = Dict(map(x -> x => d_orders(x, pdesys.bcs), all_ivs(v)))
# Create a map of each variable to their boundary conditions including initial conditions
boundarymap = parse_bcs(pdesys.bcs, v, bcorders)

if init_params === nothing
# Use the initialization of the neural network framework
# But for Lux, default to Float64
# For Flux, default to the types matching the values in the neural network
# This is done because Float64 is almost always better for these applications
# But with Flux there's already a chosen type from the user

if chain isa AbstractArray
if chain[1] isa Flux.Chain
init_params = map(chain) do x
_x = Flux.destructure(x)[1]
end
else
x = map(chain) do x
_x = ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(),
x))
Float64.(_x) # No ComponentArray GPU support
end
names = ntuple(i -> depvars[i], length(chain))
init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i
for i in x))
end
else
if chain isa Flux.Chain
init_params = Flux.destructure(chain)[1]
init_params = init_params isa Array ? Float64.(init_params) :
init_params
else
init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters(Random.default_rng(),
chain)))
end
end
else
init_params = init_params
end

if (discretization.phi isa Vector && discretization.phi[1].f isa Optimisers.Restructure) ||
(!(discretization.phi isa Vector) && discretization.phi.f isa Optimisers.Restructure)
# Flux.Chain
flat_init_params = multioutput ? reduce(vcat, init_params) : init_params
flat_init_params = param_estim == false ? flat_init_params :
vcat(flat_init_params,
adapt(typeof(flat_init_params), default_p))
else
flat_init_params = if init_params isa ComponentArrays.ComponentArray
init_params
elseif multioutput
@assert length(init_params) == length(depvars)
names = ntuple(i -> depvars[i], length(init_params))
x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params))
else
ComponentArrays.ComponentArray(init_params)
end
flat_init_params = if param_estim == false && multioutput
ComponentArrays.ComponentArray(; depvar = flat_init_params)
elseif param_estim == false && !multioutput
flat_init_params
else
ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p)
end
end

eltypeθ = eltype(flat_init_params)

if adaloss === nothing
adaloss = NonAdaptiveLoss{eltypeθ}()
end
eqdata = EquationData(pdesys, v)

multioutput = discretization.multioutput
init_params = discretization.init_params
phi = discretization.phi

if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer)
for ϕ in phi
ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)),
ϕ.st)
end
elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer)
phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)),
phi.st)
end

derivative = discretization.derivative
strategy = discretization.strategy

Expand All @@ -510,32 +436,11 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
iteration = discretization.iteration
self_increment = discretization.self_increment

if !(eqs isa Array)
eqs = [eqs]
end

pde_indvars = if strategy isa QuadratureTraining
get_argument(eqs, dict_indvars, dict_depvars)
else
get_variables(eqs, dict_indvars, dict_depvars)
end

bc_indvars = if strategy isa QuadratureTraining
get_argument(bcs, dict_indvars, dict_depvars)
else
get_variables(bcs, dict_indvars, dict_depvars)
end

pde_integration_vars = get_integration_variables(eqs, dict_indvars, dict_depvars)
bc_integration_vars = get_integration_variables(bcs, dict_indvars, dict_depvars)

pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p,
param_estim, additional_loss, adaloss, depvars, indvars,
dict_indvars, dict_depvars, dict_depvar_input, logger,
param_estim, additional_loss, adaloss, v, logger,
multioutput, iteration, init_params, flat_init_params, phi,
derivative,
strategy, pde_indvars, bc_indvars, pde_integration_vars,
bc_integration_vars, nothing, nothing, nothing, nothing)
strategy, eqdata, nothing, nothing, nothing, nothing)

integral = get_numeric_integral(pinnrep)

Expand All @@ -553,15 +458,9 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions
pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions

datafree_pde_loss_functions = [build_loss_function(pinnrep, eq, pde_indvar)
for (eq, pde_indvar, integration_indvar) in zip(eqs,
pde_indvars,
pde_integration_vars)]
datafree_pde_loss_functions = [build_loss_function(pinnrep, eq) for eq in eqs]

datafree_bc_loss_functions = [build_loss_function(pinnrep, bc, bc_indvar)
for (bc, bc_indvar, integration_indvar) in zip(bcs,
bc_indvars,
bc_integration_vars)]
datafree_bc_loss_functions = [build_loss_function(pinnrep, bc) for bc in bcs]

pde_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep,
strategy,
Expand Down
54 changes: 54 additions & 0 deletions src/eq_data.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Per-equation bookkeeping for the symbolics-based parser: records, for every
# PDE and boundary condition in the system, which variables it mentions and
# which arguments it is called with. Populated by the `EquationData(pdesys, v)`
# constructor below; read through the `depvars`/`indvars`/`pde_indvars`/
# `bc_indvars`/`argument` accessors.
struct EquationData <: PDEBase.AbstractVarEqMapping
depvarmap   # equation => dependent variables appearing in it (via get_depvars)
indvarmap   # equation => independent variables appearing in it (via get_indvars)
pde_indvars # per-PDE independent-variable lists (strategy-dependent; see constructor)
bc_indvars  # per-BC independent-variable lists (strategy-dependent; see constructor)
argmap      # equation => its argument list (first result of get_argument)
end

"""
    EquationData(pdesys, v, strategy = nothing)

Build the equation-to-variable bookkeeping for `pdesys` (its PDEs and
boundary conditions) using the variable map `v`.

`strategy` selects how per-equation independent variables are gathered: for a
`QuadratureTraining` strategy the raw equation arguments are used
(`get_argument`), otherwise the free variables (`get_variables`). It defaults
to `nothing` so existing `EquationData(pdesys, v)` callers keep working and
take the `get_variables` path. (The original body referenced `strategy`
without it being in scope, which raised `UndefVarError` at runtime.)
"""
function EquationData(pdesys, v, strategy = nothing)
    eqs = pdesys.eqs
    bcs = pdesys.bcs
    alleqs = vcat(eqs, bcs)

    # Build Dicts so the accessors (`depvars`, `indvars`, `argument`) can
    # index by equation; a bare `map` would yield a Vector of Pairs, which
    # cannot be indexed by an equation key.
    argmap = Dict(eq => get_argument([eq], v)[1] for eq in alleqs)
    depvarmap = Dict(eq => get_depvars(eq, v.depvar_ops) for eq in alleqs)
    indvarmap = Dict(eq => get_indvars(eq, indvars(v)) for eq in alleqs)

    # QuadratureTraining integrates over the equation's own arguments; every
    # other (or unspecified) strategy samples the equations' free variables.
    pde_indvars = if strategy isa QuadratureTraining
        get_argument(eqs, v)
    else
        get_variables(eqs, v)
    end

    bc_indvars = if strategy isa QuadratureTraining
        get_argument(bcs, v)
    else
        get_variables(bcs, v)
    end

    return EquationData(depvarmap, indvarmap, pde_indvars, bc_indvars, argmap)
end

"""
    depvars(eq, eqdata::EquationData)

Look up the dependent variables recorded for equation `eq` in `eqdata`.
"""
depvars(eq, eqdata::EquationData) = eqdata.depvarmap[eq]

"""
    indvars(eq, eqdata::EquationData)

Look up the independent variables recorded for equation `eq` in `eqdata`.
"""
indvars(eq, eqdata::EquationData) = eqdata.indvarmap[eq]

"""
    pde_indvars(eqdata::EquationData)

Return the per-PDE independent-variable lists stored in `eqdata`.
"""
pde_indvars(eqdata::EquationData) = eqdata.pde_indvars

"""
    bc_indvars(eqdata::EquationData)

Return the per-boundary-condition independent-variable lists stored in `eqdata`.
"""
bc_indvars(eqdata::EquationData) = eqdata.bc_indvars

argument(eq, eqdata) = eqdata.argmap[eq]
Loading