From 604f88c3bca6ec63fd36e04f31f20ee6221bc1ad Mon Sep 17 00:00:00 2001
From: ayushinav
Date: Tue, 12 Mar 2024 23:05:37 -0400
Subject: [PATCH 1/2] adding NTK adaptive loss

---
 src/adaptive_losses.jl | 107 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 107 insertions(+)

diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl
index 3a1c4a79d..79efd43d7 100644
--- a/src/adaptive_losses.jl
+++ b/src/adaptive_losses.jl
@@ -248,3 +248,110 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
         nothing
     end
 end
+
+"""
+    NeuralTangentKernelLoss(reweight_every;
+                            pde_loss_weights = 1.0,
+                            bc_loss_weights = 1.0,
+                            additional_loss_weights = 1.0)
+
+A way of adaptively reweighting the components of the loss function using the Neural
+Tangent Kernel (NTK) of the network.
+
+## Positional Arguments
+
+* `reweight_every`: how often to reweight the PDE and BC loss functions, measured in
+  iterations. Reweighting is somewhat expensive since it involves evaluating the gradient
+  of each component loss function. When set to 0, the weights are estimated at the first
+  iteration only.
+
+## References
+
+When and why PINNs fail to train: A neural tangent kernel perspective
+Sifan Wang, Xinling Yu, Paris Perdikaris
+https://arxiv.org/pdf/2007.14527
+
+With code reference:
+https://github.com/PredictiveIntelligenceLab/PINNsNTK
+"""
+mutable struct NeuralTangentKernelLoss{T <: Real} <: AbstractAdaptiveLoss
+    reweight_every::Int64
+    pde_loss_weights::Vector{T}
+    bc_loss_weights::Vector{T}
+    additional_loss_weights::Vector{T}
+    SciMLBase.@add_kwonly function NeuralTangentKernelLoss{T}(reweight_every;
+                                                              pde_loss_weights = 1.0,
+                                                              bc_loss_weights = 1.0,
+                                                              additional_loss_weights = 1.0) where {
+                                                                                                    T <:
+                                                                                                    Real
+                                                                                                    }
+        new(convert(Int64, reweight_every),
+            NeuralPDE.vectorify(pde_loss_weights, T), NeuralPDE.vectorify(bc_loss_weights, T),
+            NeuralPDE.vectorify(additional_loss_weights, T))
+    end
+end
+# default to Float64
+SciMLBase.@add_kwonly function NeuralTangentKernelLoss(reweight_every;
+                                                       pde_loss_weights = 1.0,
+                                                       bc_loss_weights = 1.0,
+                                                       additional_loss_weights = 1.0)
+    NeuralTangentKernelLoss{Float64}(reweight_every;
+                                     pde_loss_weights = pde_loss_weights,
+                                     bc_loss_weights = bc_loss_weights,
+                                     additional_loss_weights = additional_loss_weights)
+end
+
+function NeuralPDE.generate_adaptive_loss_function(pinnrep::NeuralPDE.PINNRepresentation,
+                                                   adaloss::NeuralTangentKernelLoss,
+                                                   pde_loss_functions, bc_loss_functions)
+    iteration = pinnrep.iteration
+
+    adaloss_T = eltype(adaloss.pde_loss_weights)
+
+    function run_neural_tangent_kernel_adaptive_loss(θ, pde_losses, bc_losses)
+
+        if adaloss.reweight_every == 0 && iteration[1] == 2 # when NTK remains constant throughout
+            print("CONSTANT NTK \n")
+            Kuus = [(Zygote.gradient(bc_loss_function, θ))[1] for bc_loss_function in bc_loss_functions]
+            Krrs = [(Zygote.gradient(pde_loss_function, θ))[1] for pde_loss_function in pde_loss_functions]
+
+            TrKuu = [sum(Kuu.^2) for Kuu in Kuus]
+            TrKrr = [sum(Krr.^2) for Krr in Krrs]
+
+            TrK = sum(TrKuu) + sum(TrKrr)
+            nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)
+
+            adaloss.bc_loss_weights = TrK./(TrKuu .+ nonzero_divisor_eps)
+            adaloss.pde_loss_weights = TrK./(TrKrr .+ nonzero_divisor_eps)
+
+            NeuralPDE.logvector(pinnrep.logger, adaloss.pde_loss_weights,
+                "adaptive_loss/pde_loss_weights", iteration[1])
+            NeuralPDE.logvector(pinnrep.logger, adaloss.bc_loss_weights,
+                "adaptive_loss/bc_loss_weights",
+                iteration[1])
+
+        elseif adaloss.reweight_every != 0 && iteration[1] % adaloss.reweight_every == 0
+
+            Kuus = [(Zygote.gradient(bc_loss_function, θ))[1] for bc_loss_function in bc_loss_functions]
+            Krrs = [(Zygote.gradient(pde_loss_function, θ))[1] for pde_loss_function in pde_loss_functions]
+
+            TrKuu = [sum(Kuu.^2) for Kuu in Kuus]
+            TrKrr = [sum(Krr.^2) for Krr in Krrs]
+
+            TrK = sum(TrKuu) + sum(TrKrr)
+            nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)
+
+            adaloss.bc_loss_weights = TrK./(TrKuu .+ nonzero_divisor_eps)
+            adaloss.pde_loss_weights = TrK./(TrKrr .+ nonzero_divisor_eps)
+
+            NeuralPDE.logvector(pinnrep.logger, adaloss.pde_loss_weights,
+                "adaptive_loss/pde_loss_weights", iteration[1])
+            NeuralPDE.logvector(pinnrep.logger, adaloss.bc_loss_weights,
+                "adaptive_loss/bc_loss_weights",
+                iteration[1])
+
+        end
+        nothing
+    end
+end

From 66e651a8d3e6f5919a15b5cee0648ae47425 Mon Sep 17 00:00:00 2001
From: ayushinav
Date: Mon, 13 May 2024 22:46:34 -0400
Subject: [PATCH 2/2] shifting to work on sampling strategy

---
 src/adaptive_losses.jl      | 28 ++++------------------------
 test/adaptive_loss_tests.jl |  8 +++++++-
 2 files changed, 11 insertions(+), 25 deletions(-)

diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl
index 79efd43d7..06a8c9599 100644
--- a/src/adaptive_losses.jl
+++ b/src/adaptive_losses.jl
@@ -311,27 +311,7 @@ function NeuralPDE.generate_adaptive_loss_function(pinnrep::NeuralPDE.PINNRepres
 
     function run_neural_tangent_kernel_adaptive_loss(θ, pde_losses, bc_losses)
 
-        if adaloss.reweight_every == 0 && iteration[1] == 2 # when NTK remains constant throughout
-            print("CONSTANT NTK \n")
-            Kuus = [(Zygote.gradient(bc_loss_function, θ))[1] for bc_loss_function in bc_loss_functions]
-            Krrs = [(Zygote.gradient(pde_loss_function, θ))[1] for pde_loss_function in pde_loss_functions]
-
-            TrKuu = [sum(Kuu.^2) for Kuu in Kuus]
-            TrKrr = [sum(Krr.^2) for Krr in Krrs]
-
-            TrK = sum(TrKuu) + sum(TrKrr)
-            nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)
-
-            adaloss.bc_loss_weights = TrK./(TrKuu .+ nonzero_divisor_eps)
-            adaloss.pde_loss_weights = TrK./(TrKrr .+ nonzero_divisor_eps)
-
-            NeuralPDE.logvector(pinnrep.logger, adaloss.pde_loss_weights,
-                "adaptive_loss/pde_loss_weights", iteration[1])
-            NeuralPDE.logvector(pinnrep.logger, adaloss.bc_loss_weights,
-                "adaptive_loss/bc_loss_weights",
-                iteration[1])
-
-        elseif adaloss.reweight_every != 0 && iteration[1] % adaloss.reweight_every == 0
+        if iteration[1] % adaloss.reweight_every == 0
 
             Kuus = [(Zygote.gradient(bc_loss_function, θ))[1] for bc_loss_function in bc_loss_functions]
             Krrs = [(Zygote.gradient(pde_loss_function, θ))[1] for pde_loss_function in pde_loss_functions]
@@ -340,10 +320,10 @@ function NeuralPDE.generate_adaptive_loss_function(pinnrep::NeuralPDE.PINNRepres
             TrKrr = [sum(Krr.^2) for Krr in Krrs]
 
             TrK = sum(TrKuu) + sum(TrKrr)
-            nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)
+            # nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)
 
-            adaloss.bc_loss_weights = TrK./(TrKuu .+ nonzero_divisor_eps)
-            adaloss.pde_loss_weights = TrK./(TrKrr .+ nonzero_divisor_eps)
+            adaloss.bc_loss_weights = TrK./TrKuu
+            adaloss.pde_loss_weights = TrK./TrKrr
 
             NeuralPDE.logvector(pinnrep.logger, adaloss.pde_loss_weights,
                 "adaptive_loss/pde_loss_weights", iteration[1])
             NeuralPDE.logvector(pinnrep.logger, adaloss.bc_loss_weights,
                 "adaptive_loss/bc_loss_weights",
                 iteration[1])
diff --git a/test/adaptive_loss_tests.jl b/test/adaptive_loss_tests.jl
index 8180a6895..f9e57752c 100644
--- a/test/adaptive_loss_tests.jl
+++ b/test/adaptive_loss_tests.jl
@@ -10,7 +10,10 @@ gradnormadaptive_loss = NeuralPDE.GradientScaleAdaptiveLoss(100, pde_loss_weight
                                                             bc_loss_weights = 1)
 adaptive_loss = NeuralPDE.MiniMaxAdaptiveLoss(100; pde_loss_weights = 1,
                                               bc_loss_weights = 1)
-adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss]
+ntk_loss = NeuralPDE.NeuralTangentKernelLoss(10; pde_loss_weights = 1, bc_loss_weights = 1)
+
+adaptive_losses = [nonadaptive_loss, gradnormadaptive_loss, adaptive_loss, ntk_loss]
+# adaptive_losses = [nonadaptive_loss, ntk_loss]
 maxiters = 4000
 seed = 60
 
@@ -80,8 +83,11 @@ error_results_no_logs = map(test_2d_poisson_equation_adaptive_loss_no_logs_run_s
 @show error_results_no_logs[1][:total_diff_rel]
 @show error_results_no_logs[2][:total_diff_rel]
 @show error_results_no_logs[3][:total_diff_rel]
+@show error_results_no_logs[4][:total_diff_rel]
+
 # accuracy tests, these work for this specific seed but might not for others
 # note that this doesn't test that the adaptive losses are outperforming the nonadaptive loss, which is not guaranteed, and seed/arch/hyperparam/pde etc dependent
 @test error_results_no_logs[1][:total_diff_rel] < 0.4
 @test error_results_no_logs[2][:total_diff_rel] < 0.4
 @test error_results_no_logs[3][:total_diff_rel] < 0.4
+@test error_results_no_logs[4][:total_diff_rel] < 0.4
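
The reweighting above follows the NTK recipe from the Wang et al. reference: each component's weight is the ratio of the total kernel trace to that component's block trace, `bc_loss_weights = Tr(K) / Tr(K_uu)` and `pde_loss_weights = Tr(K) / Tr(K_rr)`, where the patch approximates each block trace by the squared gradient norm of the corresponding scalar loss (that is what `sum(Kuu.^2)` computes on a `Zygote.gradient` result). The snippet below is a minimal standalone sketch of just that computation; the parameter vector and loss closures are made-up stand-ins for illustration, not anything from NeuralPDE.

```julia
using Zygote

# Stand-in parameter vector and component losses (illustrative only).
θ = randn(6)
bc_losses  = [θ -> sum(abs2, θ[1:3]), θ -> sum(abs2, θ .- 1.0)]
pde_losses = [θ -> sum(abs2, sin.(θ))]

# Same proxy the patch uses: squared gradient norm of each component loss.
TrKuu = [sum(abs2, Zygote.gradient(f, θ)[1]) for f in bc_losses]
TrKrr = [sum(abs2, Zygote.gradient(f, θ)[1]) for f in pde_losses]
TrK   = sum(TrKuu) + sum(TrKrr)

# NTK-style weights: components with a smaller trace get a larger weight.
bc_loss_weights  = TrK ./ TrKuu
pde_loss_weights = TrK ./ TrKrr
```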
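For how the new loss is meant to be driven end to end (beyond the unit test above), a hedged usage sketch in the style of the existing adaptive-loss examples follows. It assumes the usual `PhysicsInformedNN(chain, strategy; adaptive_loss = ...)` entry point; the Poisson setup, network size, `reweight_every = 50`, and optimizer settings are arbitrary illustrative choices rather than values taken from this PR.

```julia
using NeuralPDE, ModelingToolkit, Lux, Optimization, OptimizationOptimisers
import ModelingToolkit: Interval

@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D Poisson problem with homogeneous Dirichlet boundary conditions.
eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)
bcs = [u(0, y) ~ 0.0, u(1, y) ~ 0.0, u(x, 0) ~ 0.0, u(x, 1) ~ 0.0]
domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)]

chain = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 16, tanh), Lux.Dense(16, 1))

# Recompute the NTK-based weights every 50 iterations (arbitrary choice).
ntk_loss = NeuralPDE.NeuralTangentKernelLoss(50; pde_loss_weights = 1, bc_loss_weights = 1)

discretization = PhysicsInformedNN(chain, GridTraining(0.1); adaptive_loss = ntk_loss)
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
prob = discretize(pde_system, discretization)

res = Optimization.solve(prob, OptimizationOptimisers.Adam(3e-4); maxiters = 2000)
```

With `reweight_every = 50`, the closure returned by `generate_adaptive_loss_function` recomputes `pde_loss_weights` and `bc_loss_weights` from the gradient traces every 50 iterations and logs them under `adaptive_loss/...` when a logger is attached.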