Adding NTK adaptive loss function #834

Closed · wants to merge 7 commits
87 changes: 87 additions & 0 deletions src/adaptive_losses.jl
@@ -253,3 +253,90 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
        nothing
    end
end

"""
    NeuralTangentKernelLoss(reweight_every;
                            pde_loss_weights = 1.0,
                            bc_loss_weights = 1.0,
                            additional_loss_weights = 1.0)

Adaptively reweights the components of the loss function using the Neural Tangent
Kernel (NTK) of the network.
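
At each reweighting step the weights are set from the traces of the NTK blocks;
writing `Kuu` for the boundary block and `Krr` for the residual block of the kernel,
the update implemented here is

```math
\lambda_{bc} = \frac{\operatorname{Tr}(K)}{\operatorname{Tr}(K_{uu})}, \qquad
\lambda_{pde} = \frac{\operatorname{Tr}(K)}{\operatorname{Tr}(K_{rr})}, \qquad
\operatorname{Tr}(K) = \operatorname{Tr}(K_{uu}) + \operatorname{Tr}(K_{rr}).
```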

## Positional Arguments

* `reweight_every`: how often to reweight the PDE and BC loss functions, measured in
  iterations. Reweighting is somewhat expensive since it involves evaluating the
  gradient of each component loss function. When set to 0, the weights are estimated
  at the first iteration only.

## References

When and why PINNs fail to train: A neural tangent kernel perspective
Sifan Wang, Xinling Yu, Paris Perdikaris
https://arxiv.org/pdf/2007.14527

Reference implementation:
https://github.com/PredictiveIntelligenceLab/PINNsNTK
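
## Example

A sketch of intended usage; the network architecture, training strategy, and
hyperparameter values below are illustrative placeholders:

```julia
using NeuralPDE, Lux

chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
ntk_loss = NeuralTangentKernelLoss(100; pde_loss_weights = 1.0,
                                   bc_loss_weights = 1.0)
# Passed to the discretization like the other adaptive losses:
discretization = PhysicsInformedNN(chain, GridTraining(0.1);
                                   adaptive_loss = ntk_loss)
```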
"""
mutable struct NeuralTangentKernelLoss{T <: Real} <: AbstractAdaptiveLoss
    reweight_every::Int64
    pde_loss_weights::Vector{T}
    bc_loss_weights::Vector{T}
    additional_loss_weights::Vector{T}
    SciMLBase.@add_kwonly function NeuralTangentKernelLoss{T}(reweight_every;
            pde_loss_weights = 1.0,
            bc_loss_weights = 1.0,
            additional_loss_weights = 1.0) where {T <: Real}
        new(convert(Int64, reweight_every),
            NeuralPDE.vectorify(pde_loss_weights, T),
            NeuralPDE.vectorify(bc_loss_weights, T),
            NeuralPDE.vectorify(additional_loss_weights, T))
    end
end
# default to Float64
SciMLBase.@add_kwonly function NeuralTangentKernelLoss(reweight_every;
        pde_loss_weights = 1.0,
        bc_loss_weights = 1.0,
        additional_loss_weights = 1.0)
    NeuralTangentKernelLoss{Float64}(reweight_every;
        pde_loss_weights = pde_loss_weights,
        bc_loss_weights = bc_loss_weights,
        additional_loss_weights = additional_loss_weights)
end

function NeuralPDE.generate_adaptive_loss_function(pinnrep::NeuralPDE.PINNRepresentation,
        adaloss::NeuralTangentKernelLoss,
        pde_loss_functions, bc_loss_functions)
    iteration = pinnrep.iteration
    adaloss_T = eltype(adaloss.pde_loss_weights)

    function run_neural_tangent_kernel_adaptive_loss(θ, pde_losses, bc_losses)
        # `reweight_every == 0` estimates the weights at the first iteration only
        # (avoids `iteration[1] % 0`, which would throw a `DivideError`).
        reweight_now = adaloss.reweight_every == 0 ? iteration[1] <= 1 :
                       iteration[1] % adaloss.reweight_every == 0
        if reweight_now
            # Gradient of each component loss w.r.t. θ; the squared norm of each
            # gradient is used as the trace of the corresponding NTK block
            # (Kuu: boundary conditions, Krr: PDE residuals).
            Kuus = [Zygote.gradient(bc_loss_function, θ)[1]
                    for bc_loss_function in bc_loss_functions]
            Krrs = [Zygote.gradient(pde_loss_function, θ)[1]
                    for pde_loss_function in pde_loss_functions]

            TrKuu = [sum(Kuu .^ 2) for Kuu in Kuus]
            TrKrr = [sum(Krr .^ 2) for Krr in Krrs]
            TrK = sum(TrKuu) + sum(TrKrr)

            # Guard the divisions below against a vanishing component gradient.
            nonzero_divisor_eps = adaloss_T === Float64 ? Float64(1e-11) :
                                  convert(adaloss_T, 1e-7)

            # λ_bc = Tr(K)/Tr(Kuu), λ_pde = Tr(K)/Tr(Krr) (Wang et al., 2020).
            adaloss.bc_loss_weights = TrK ./ (TrKuu .+ nonzero_divisor_eps)
            adaloss.pde_loss_weights = TrK ./ (TrKrr .+ nonzero_divisor_eps)

            NeuralPDE.logvector(pinnrep.logger, adaloss.pde_loss_weights,
                "adaptive_loss/pde_loss_weights", iteration[1])
            NeuralPDE.logvector(pinnrep.logger, adaloss.bc_loss_weights,
                "adaptive_loss/bc_loss_weights", iteration[1])
        end
        nothing
    end
end
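
As a self-contained illustration of the trace proxy above: for a scalar loss,
`Zygote.gradient` returns the gradient vector, so `sum(g .^ 2)` is the squared
gradient norm that stands in for the trace of each NTK block. The toy loss below
is illustrative only:

```julia
using Zygote

# ℓ(θ) = ½‖θ − 1‖², so ∇ℓ(θ) = θ − 1 and ‖∇ℓ‖² = ‖θ − 1‖².
θ = randn(5)
ℓ(θ) = sum(abs2, θ .- 1.0) / 2
g = Zygote.gradient(ℓ, θ)[1]
@assert isapprox(sum(g .^ 2), sum(abs2, θ .- 1.0))
```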
5 changes: 4 additions & 1 deletion test/adaptive_loss_tests.jl
@@ -9,7 +9,7 @@ nonadaptive_loss = NeuralPDE.NonAdaptiveLoss(pde_loss_weights = 1, bc_loss_weigh
gradnormadaptive_loss = NeuralPDE.GradientScaleAdaptiveLoss(100, pde_loss_weights = 1e3,
bc_loss_weights = 1)
adaptive_loss = NeuralPDE.MiniMaxAdaptiveLoss(100; pde_loss_weights = 1,
                                              bc_loss_weights = 1)
ntkadaptive_loss = NeuralPDE.NeuralTangentKernelLoss(100; pde_loss_weights = 1,
                                                     bc_loss_weights = 1)
adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss,
                   ntkadaptive_loss]
maxiters = 4000
seed = 60
@@ -81,8 +81,11 @@ error_results_no_logs = map(test_2d_poisson_equation_adaptive_loss_no_logs_run_s
@show error_results_no_logs[1][:total_diff_rel]
@show error_results_no_logs[2][:total_diff_rel]
@show error_results_no_logs[3][:total_diff_rel]
@show error_results_no_logs[4][:total_diff_rel]

# accuracy tests; these pass for this specific seed but might not for others
# note that this doesn't test that the adaptive losses outperform the nonadaptive loss,
# which is not guaranteed and depends on the seed, architecture, hyperparameters, PDE, etc.
@test error_results_no_logs[1][:total_diff_rel] < 0.4
@test error_results_no_logs[2][:total_diff_rel] < 0.4
@test error_results_no_logs[3][:total_diff_rel] < 0.4
@test error_results_no_logs[4][:total_diff_rel] < 0.4