diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl
index ca949ec45..08c754d82 100644
--- a/src/adaptive_losses.jl
+++ b/src/adaptive_losses.jl
@@ -253,3 +253,88 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
         nothing
     end
 end
+
+"""
+    NeuralTangentKernelLoss(reweight_every;
+                            pde_loss_weights = 1.0,
+                            bc_loss_weights = 1.0,
+                            additional_loss_weights = 1.0)
+
+A way of adaptively reweighting the components of the loss function using weights derived
+from the Neural Tangent Kernel (NTK) of the network.
+
+## Positional Arguments
+
+* `reweight_every`: how often to recompute the PDE and BC loss weights, measured in
+  iterations. Reweighting is somewhat expensive since it involves evaluating the gradient
+  of each component loss function.
+
+## References
+
+When and why PINNs fail to train: A neural tangent kernel perspective
+Sifan Wang, Xinling Yu, Paris Perdikaris
+https://arxiv.org/pdf/2007.14527
+
+With code reference:
+https://github.com/PredictiveIntelligenceLab/PINNsNTK
+"""
+mutable struct NeuralTangentKernelLoss{T <: Real} <: AbstractAdaptiveLoss
+    reweight_every::Int64
+    pde_loss_weights::Vector{T}
+    bc_loss_weights::Vector{T}
+    additional_loss_weights::Vector{T}
+    SciMLBase.@add_kwonly function NeuralTangentKernelLoss{T}(reweight_every;
+            pde_loss_weights = 1.0,
+            bc_loss_weights = 1.0,
+            additional_loss_weights = 1.0) where {T <: Real}
+        new(convert(Int64, reweight_every),
+            NeuralPDE.vectorify(pde_loss_weights, T),
+            NeuralPDE.vectorify(bc_loss_weights, T),
+            NeuralPDE.vectorify(additional_loss_weights, T))
+    end
+end
+# default to Float64
+SciMLBase.@add_kwonly function NeuralTangentKernelLoss(reweight_every;
+        pde_loss_weights = 1.0,
+        bc_loss_weights = 1.0,
+        additional_loss_weights = 1.0)
+    NeuralTangentKernelLoss{Float64}(reweight_every;
+        pde_loss_weights = pde_loss_weights,
+        bc_loss_weights = bc_loss_weights,
+        additional_loss_weights = additional_loss_weights)
+end
+
+function NeuralPDE.generate_adaptive_loss_function(pinnrep::NeuralPDE.PINNRepresentation,
+        adaloss::NeuralTangentKernelLoss,
+        pde_loss_functions, bc_loss_functions)
+    iteration = pinnrep.iteration
+    adaloss_T = eltype(adaloss.pde_loss_weights)
+
+    function run_neural_tangent_kernel_adaptive_loss(θ, pde_losses, bc_losses)
+        if iteration[1] % adaloss.reweight_every == 0
+            # approximate the NTK traces Tr(K_uu) and Tr(K_rr) by the squared gradient of
+            # each boundary-condition and PDE loss component with respect to θ
+            Kuus = [Zygote.gradient(bc_loss_function, θ)[1]
+                    for bc_loss_function in bc_loss_functions]
+            Krrs = [Zygote.gradient(pde_loss_function, θ)[1]
+                    for pde_loss_function in pde_loss_functions]
+
+            TrKuu = [sum(abs2, Kuu) for Kuu in Kuus]
+            TrKrr = [sum(abs2, Krr) for Krr in Krrs]
+            TrK = sum(TrKuu) + sum(TrKrr)
+
+            # guard against division by zero when a component gradient vanishes
+            nonzero_divisor_eps = adaloss_T == Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)
+
+            # weight each component by Tr(K) / Tr(K_ii), as in Wang et al.
+            adaloss.bc_loss_weights = TrK ./ (TrKuu .+ nonzero_divisor_eps)
+            adaloss.pde_loss_weights = TrK ./ (TrKrr .+ nonzero_divisor_eps)
+
+            NeuralPDE.logvector(pinnrep.logger, adaloss.pde_loss_weights,
+                                "adaptive_loss/pde_loss_weights", iteration[1])
+            NeuralPDE.logvector(pinnrep.logger, adaloss.bc_loss_weights,
+                                "adaptive_loss/bc_loss_weights", iteration[1])
+        end
+        nothing
+    end
+end
diff --git a/test/adaptive_loss_tests.jl b/test/adaptive_loss_tests.jl
index 5259a019f..f3357e2c0 100644
--- a/test/adaptive_loss_tests.jl
+++ b/test/adaptive_loss_tests.jl
@@ -9,7 +9,10 @@ nonadaptive_loss = NeuralPDE.NonAdaptiveLoss(pde_loss_weights = 1, bc_loss_weigh
 gradnormadaptive_loss = NeuralPDE.GradientScaleAdaptiveLoss(100, pde_loss_weights = 1e3,
                                                             bc_loss_weights = 1)
 adaptive_loss = NeuralPDE.MiniMaxAdaptiveLoss(100; pde_loss_weights = 1,
-                                               bc_loss_weights = 1)
+                                              bc_loss_weights = 1)
+# reweight_every = 100 mirrors the other adaptive losses above
+ntkadaptive_loss = NeuralPDE.NeuralTangentKernelLoss(100; pde_loss_weights = 1,
+                                                     bc_loss_weights = 1)
-adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss]
+adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss, ntkadaptive_loss]
 maxiters = 4000
 seed = 60
@@ -81,8 +81,11 @@ error_results_no_logs = map(test_2d_poisson_equation_adaptive_loss_no_logs_run_s
 @show error_results_no_logs[1][:total_diff_rel]
 @show error_results_no_logs[2][:total_diff_rel]
 @show error_results_no_logs[3][:total_diff_rel]
+@show error_results_no_logs[4][:total_diff_rel]
+
 # accuracy tests, these work for this specific seed but might not for others
 # note that this doesn't test that the adaptive losses are outperforming the nonadaptive loss, which is not guaranteed, and seed/arch/hyperparam/pde etc dependent
 @test error_results_no_logs[1][:total_diff_rel] < 0.4
 @test error_results_no_logs[2][:total_diff_rel] < 0.4
 @test error_results_no_logs[3][:total_diff_rel] < 0.4
+@test error_results_no_logs[4][:total_diff_rel] < 0.4
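
For context, here is a minimal usage sketch (not part of the diff) of how the new loss would plug into the usual `PhysicsInformedNN` workflow, on a 2D Poisson problem similar to the one in the adaptive-loss tests. The problem setup, network size, `GridTraining` step, `reweight_every = 100`, and optimiser settings are illustrative assumptions rather than values taken from this PR, and the constructor is called as `NeuralPDE.NeuralTangentKernelLoss` because the diff does not add an `export` for it.

```julia
# Illustrative sketch only: problem, architecture, and hyperparameters are assumptions,
# not part of this PR.
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimisers
import ModelingToolkit: Interval

@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D Poisson equation with zero Dirichlet boundary conditions
eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)
bcs = [u(0, y) ~ 0.0, u(1, y) ~ 0.0, u(x, 0) ~ 0.0, u(x, 1) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)]
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])

chain = Chain(Dense(2, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1))

# recompute the NTK-based weights every 100 iterations (assumed value)
ntk_loss = NeuralPDE.NeuralTangentKernelLoss(100; pde_loss_weights = 1,
                                             bc_loss_weights = 1)

discretization = PhysicsInformedNN(chain, GridTraining(0.1); adaptive_loss = ntk_loss)
prob = discretize(pde_system, discretization)
res = solve(prob, OptimizationOptimisers.Adam(3e-4); maxiters = 2000)
```

Larger `reweight_every` values amortize the cost of the extra gradient evaluations at each reweighting step, at the price of the weights adapting more slowly.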