From 598a9f3bbb59ee2210bfee135f61d8e1b75f9f4e Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 11 Sep 2024 11:13:44 -0700 Subject: [PATCH] basis tests pass again --- src/models/Rnl_splines.jl | 4 ++-- src/models/ace.jl | 46 ++++++++++++++++++++++++--------------- test/models/test_ace.jl | 23 +++++++++++++++++++- 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/src/models/Rnl_splines.jl b/src/models/Rnl_splines.jl index 4e9fb17e..29624276 100644 --- a/src/models/Rnl_splines.jl +++ b/src/models/Rnl_splines.jl @@ -133,9 +133,9 @@ end function rrule(::typeof(evaluate_batched), basis::SplineRnlrzzBasis, rs, zi, zjs, ps, st) - Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st) + Rnl = evaluate_batched(basis, rs, zi, zjs, ps, st) - return (Rnl, st), + return Rnl, Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NamedTuple(), NoTangent()) end diff --git a/src/models/ace.jl b/src/models/ace.jl index 620a81a7..654b839d 100644 --- a/src/models/ace.jl +++ b/src/models/ace.jl @@ -611,23 +611,24 @@ end __vec(Rs::AbstractVector{SVector{3, T}}) where {T} = reinterpret(T, Rs) __svecs(Rsvec::AbstractVector{T}) where {T} = reinterpret(SVector{3, T}, Rsvec) -function evaluate_basis_ed(model::ACEModel, +function evaluate_basis_ed_old(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} if length(Rs) == 0 B = zeros(T, length_basis(model)) dB = zeros(SVector{3, T}, (0, length_basis(model))) - end + else - B = evaluate_basis(model, Rs, Zs, Z0, ps, st) + B = evaluate_basis(model, Rs, Zs, Z0, ps, st) - dB_vec = ForwardDiff.jacobian( - _Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st), - __vec(Rs)) - dB1 = __svecs(collect(dB_vec')[:]) - dB = collect( permutedims( reshape(dB1, length(Rs), length(B)), - (2, 1) ) ) + dB_vec = ForwardDiff.jacobian( + _Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st), + __vec(Rs)) + dB1 = __svecs(collect(dB_vec')[:]) + dB = collect( permutedims( reshape(dB1, length(Rs), length(B)), + (2, 1) ) ) + end return B, dB end @@ -655,7 +656,7 @@ end # --------------------------------------------------------- # experimental pushforwards -function _evaluate_basis_ed(model::ACEModel, +function evaluate_basis_ed(model::ACEModel, Rs::AbstractVector{SVector{3, T}}, Zs, Z0, ps, st) where {T} @@ -685,7 +686,7 @@ function _evaluate_basis_ed(model::ACEModel, end # pushfoward through the sparse tensor - this completes the MB jacobian. - B, ∂B = _pfwd(model.tensor, Rnl, Ylm, ∂Rnl, ∂Ylm) + Bmb_i, ∂Bmb_i = _pfwd(model.tensor, Rnl, Ylm, ∂Rnl, ∂Ylm) # ------------------- # pair potential @@ -693,20 +694,29 @@ function _evaluate_basis_ed(model::ACEModel, Rnl2, dRnl2 = @withalloc evaluate_ed_batched!(model.pairbasis, rs, Z0, Zs, ps.pairbasis, st.pairbasis) - Apair = sum(Rnl2, dims=1)[:] - ∂Apair = zeros(eltype(∂B), size(Rnl2, 2), size(Rnl2, 1)) + B2_i = sum(Rnl2, dims=1)[:] + ∂B2_i = zeros(eltype(∂Bmb_i), size(Rnl2, 2), size(Rnl2, 1)) for nl = 1:size(dRnl2, 2) for j = 1:size(dRnl2, 1) - ∂Apair[nl, j] = dRnl2[j, nl] * ∇rs[j] + ∂B2_i[nl, j] = dRnl2[j, nl] * ∇rs[j] end end - - B = [ B; Apair ] - ∂B = [ ∂B; ∂Apair ] - end + else + B2_i = zeros(eltype(Bmb_i), 0) + ∂B2_i = zeros(eltype(∂Bmb_i), 0, size(Bmb_i, 2)) + end # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ end # @no_escape + TB = promote_type(eltype(Bmb_i), eltype(B2_i)) + ∂TB = promote_type(eltype(∂Bmb_i), eltype(∂B2_i)) + B = zeros(TB, length_basis(model)) + B[get_basis_inds(model, Z0)] .= Bmb_i + B[get_pairbasis_inds(model, Z0)] .= B2_i + ∂B = zeros(∂TB, length_basis(model), size(∂Bmb_i, 2)) + ∂B[get_basis_inds(model, Z0), :] .= ∂Bmb_i + ∂B[get_pairbasis_inds(model, Z0), :] .= ∂B2_i + return B, ∂B end \ No newline at end of file diff --git a/test/models/test_ace.jl b/test/models/test_ace.jl index 2cbbcbf6..67804365 100644 --- a/test/models/test_ace.jl +++ b/test/models/test_ace.jl @@ -162,7 +162,9 @@ for ybasis in [:spherical, :solid] print_tf(@test Ei ≈ dot(B, θ)) Ei, ∇Ei = M.evaluate_ed(model, Rs, Zs, z0, ps, st) - B, ∇B = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) + + B1, ∇B = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st) + print_tf(@test B ≈ B1) print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:]) end println() @@ -208,6 +210,25 @@ for ybasis in [:spherical, :solid] end println() +## + + @info("After splinification check correctness of evaluate_basis_ed again") + for ntest = 1:10 + local Nat, Rs, Zs, z0, Ei + Nat = rand(8:16) + Rs, Zs, z0 = M.rand_atenv(model, Nat) + Us = randn(SVector{3, Float64}, Nat) / sqrt(Nat) + B, ∂B = M.evaluate_basis_ed(lin_ace, Rs, Zs, z0, ps, st) + B0 = M.evaluate_basis(lin_ace, Rs, Zs, z0, ps, st) + ∂B0 = M.jacobian_grad_params(lin_ace, Rs, Zs, z0, ps, st)[3] + print_tf(@test B ≈ B0) + print_tf(@test ∂B ≈ ∂B0) + end + println() + # ∂E0 = ForwardDiff.derivative( + # t -> M.evaluate_basis(lin_ace, Rs + t * Us, Zs, z0, ps, st), + # 0.0) + end ##