Skip to content

Commit

Permalink
basis tests pass again
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 11, 2024
1 parent 398948c commit 598a9f3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/models/Rnl_splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ end
function rrule(::typeof(evaluate_batched),
basis::SplineRnlrzzBasis,
rs, zi, zjs, ps, st)
Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st)
Rnl = evaluate_batched(basis, rs, zi, zjs, ps, st)

return (Rnl, st),
return Rnl,
Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NamedTuple(), NoTangent())
end
46 changes: 28 additions & 18 deletions src/models/ace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -611,23 +611,24 @@ end
__vec(Rs::AbstractVector{SVector{3, T}}) where {T} = reinterpret(T, Rs)
__svecs(Rsvec::AbstractVector{T}) where {T} = reinterpret(SVector{3, T}, Rsvec)

function evaluate_basis_ed(model::ACEModel,
function evaluate_basis_ed_old(model::ACEModel,
Rs::AbstractVector{SVector{3, T}}, Zs, Z0,
ps, st) where {T}

if length(Rs) == 0
B = zeros(T, length_basis(model))
dB = zeros(SVector{3, T}, (0, length_basis(model)))
end
else

B = evaluate_basis(model, Rs, Zs, Z0, ps, st)
B = evaluate_basis(model, Rs, Zs, Z0, ps, st)

dB_vec = ForwardDiff.jacobian(
_Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st),
__vec(Rs))
dB1 = __svecs(collect(dB_vec')[:])
dB = collect( permutedims( reshape(dB1, length(Rs), length(B)),
(2, 1) ) )
dB_vec = ForwardDiff.jacobian(
_Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st),
__vec(Rs))
dB1 = __svecs(collect(dB_vec')[:])
dB = collect( permutedims( reshape(dB1, length(Rs), length(B)),
(2, 1) ) )
end

return B, dB
end
Expand Down Expand Up @@ -655,7 +656,7 @@ end
# ---------------------------------------------------------
# experimental pushforwards

function _evaluate_basis_ed(model::ACEModel,
function evaluate_basis_ed(model::ACEModel,
Rs::AbstractVector{SVector{3, T}}, Zs, Z0,
ps, st) where {T}

Expand Down Expand Up @@ -685,28 +686,37 @@ function _evaluate_basis_ed(model::ACEModel,
end

# pushfoward through the sparse tensor - this completes the MB jacobian.
B, ∂B = _pfwd(model.tensor, Rnl, Ylm, ∂Rnl, ∂Ylm)
Bmb_i, ∂Bmb_i = _pfwd(model.tensor, Rnl, Ylm, ∂Rnl, ∂Ylm)

# -------------------
# pair potential
if model.pairbasis != nothing
Rnl2, dRnl2 = @withalloc evaluate_ed_batched!(model.pairbasis,
rs, Z0, Zs,
ps.pairbasis, st.pairbasis)
Apair = sum(Rnl2, dims=1)[:]
Apair = zeros(eltype(∂B), size(Rnl2, 2), size(Rnl2, 1))
B2_i = sum(Rnl2, dims=1)[:]
B2_i = zeros(eltype(∂Bmb_i), size(Rnl2, 2), size(Rnl2, 1))
for nl = 1:size(dRnl2, 2)
for j = 1:size(dRnl2, 1)
Apair[nl, j] = dRnl2[j, nl] * ∇rs[j]
B2_i[nl, j] = dRnl2[j, nl] * ∇rs[j]
end
end

B = [ B; Apair ]
B = [ ∂B; ∂Apair ]
end
else
B2_i = zeros(eltype(Bmb_i), 0)
B2_i = zeros(eltype(∂Bmb_i), 0, size(Bmb_i, 2))
end

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
end # @no_escape

TB = promote_type(eltype(Bmb_i), eltype(B2_i))
∂TB = promote_type(eltype(∂Bmb_i), eltype(∂B2_i))
B = zeros(TB, length_basis(model))
B[get_basis_inds(model, Z0)] .= Bmb_i
B[get_pairbasis_inds(model, Z0)] .= B2_i
∂B = zeros(∂TB, length_basis(model), size(∂Bmb_i, 2))
∂B[get_basis_inds(model, Z0), :] .= ∂Bmb_i
∂B[get_pairbasis_inds(model, Z0), :] .= ∂B2_i

return B, ∂B
end
23 changes: 22 additions & 1 deletion test/models/test_ace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ for ybasis in [:spherical, :solid]
print_tf(@test Ei dot(B, θ))

Ei, ∇Ei = M.evaluate_ed(model, Rs, Zs, z0, ps, st)
B, ∇B = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st)

B1, ∇B = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st)
print_tf(@test B B1)
print_tf(@test ∇Ei sum.* ∇B, dims=1)[:])
end
println()
Expand Down Expand Up @@ -208,6 +210,25 @@ for ybasis in [:spherical, :solid]
end
println()

##

@info("After splinification check correctness of evaluate_basis_ed again")
for ntest = 1:10
local Nat, Rs, Zs, z0, Ei
Nat = rand(8:16)
Rs, Zs, z0 = M.rand_atenv(model, Nat)
Us = randn(SVector{3, Float64}, Nat) / sqrt(Nat)
B, ∂B = M.evaluate_basis_ed(lin_ace, Rs, Zs, z0, ps, st)
B0 = M.evaluate_basis(lin_ace, Rs, Zs, z0, ps, st)
∂B0 = M.jacobian_grad_params(lin_ace, Rs, Zs, z0, ps, st)[3]
print_tf(@test B B0)
print_tf(@test ∂B ∂B0)
end
println()
# ∂E0 = ForwardDiff.derivative(
# t -> M.evaluate_basis(lin_ace, Rs + t * Us, Zs, z0, ps, st),
# 0.0)

end

##
Expand Down

0 comments on commit 598a9f3

Please sign in to comment.