Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pushforwards #246

Merged
merged 3 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/models/Rnl_splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ end
function rrule(::typeof(evaluate_batched),
basis::SplineRnlrzzBasis,
rs, zi, zjs, ps, st)
Rnl, st = evaluate_batched(basis, rs, zi, zjs, ps, st)
Rnl = evaluate_batched(basis, rs, zi, zjs, ps, st)

return (Rnl, st),
return Rnl,
Δ -> (NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NamedTuple(), NoTangent())
end
96 changes: 84 additions & 12 deletions src/models/ace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,8 @@ function evaluate_ed(model::ACEModel,
rs, ∇rs = @withalloc radii_ed!(Rs)

# evaluate the radial basis
# TODO: using @withalloc causes stack overflow
Rnl, dRnl = @withalloc evaluate_ed_batched!(model.rbasis, rs, Z0, Zs,
ps.rbasis, st.rbasis)
# Rnl, dRnl = evaluate_ed_batched(model.rbasis, rs, Z0, Zs,
# ps.rbasis, st.rbasis)

# evaluate the Y basis
Ylm, dYlm = @withalloc P4ML.evaluate_ed!(model.ybasis, Rs)
Expand Down Expand Up @@ -614,23 +611,24 @@ end
__vec(Rs::AbstractVector{SVector{3, T}}) where {T} = reinterpret(T, Rs)
__svecs(Rsvec::AbstractVector{T}) where {T} = reinterpret(SVector{3, T}, Rsvec)

function evaluate_basis_ed(model::ACEModel,
function evaluate_basis_ed_old(model::ACEModel,
Rs::AbstractVector{SVector{3, T}}, Zs, Z0,
ps, st) where {T}

if length(Rs) == 0
B = zeros(T, length_basis(model))
dB = zeros(SVector{3, T}, (0, length_basis(model)))
end
else

B = evaluate_basis(model, Rs, Zs, Z0, ps, st)
B = evaluate_basis(model, Rs, Zs, Z0, ps, st)

dB_vec = ForwardDiff.jacobian(
_Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st),
__vec(Rs))
dB1 = __svecs(collect(dB_vec')[:])
dB = collect( permutedims( reshape(dB1, length(Rs), length(B)),
(2, 1) ) )
dB_vec = ForwardDiff.jacobian(
_Rs -> evaluate_basis(model, __svecs(_Rs), Zs, Z0, ps, st),
__vec(Rs))
dB1 = __svecs(collect(dB_vec')[:])
dB = collect( permutedims( reshape(dB1, length(Rs), length(B)),
(2, 1) ) )
end

return B, dB
end
Expand All @@ -653,3 +651,77 @@ function jacobian_grad_params(model::ACEModel,
return Ei, ∂Ei_vec, ∂∂Ei, st
end



# ---------------------------------------------------------
# experimental pushforwards

function evaluate_basis_ed(model::ACEModel,
Rs::AbstractVector{SVector{3, T}}, Zs, Z0,
ps, st) where {T}

TB = T
∂TB = SVector{3, T}
B = zeros(TB, length_basis(model))
∂B = zeros(∂TB, length_basis(model), length(Rs))

if length(Rs) == 0
return B, ∂B
end

@no_escape begin
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# get the radii
rs, ∇rs = @withalloc radii_ed!(Rs)

# evaluate the radial basis
Rnl, dRnl = @withalloc evaluate_ed_batched!(model.rbasis, rs, Z0, Zs,
ps.rbasis, st.rbasis)

# evaluate the Y basis
Ylm, dYlm = @withalloc P4ML.evaluate_ed!(model.ybasis, Rs)

# compute vectorial dRnl
∂Ylm = dYlm
∂Rnl = @alloc(eltype(dYlm), size(dRnl)...)
for nl = 1:size(dRnl, 2)
# @inbounds begin
# @simd ivdep
for j = 1:size(dRnl, 1)
∂Rnl[j, nl] = dRnl[j, nl] * ∇rs[j]
end
# end
end

# pushfoward through the sparse tensor - this completes the MB jacobian.
Bmb_i, ∂Bmb_i = _pfwd(model.tensor, Rnl, Ylm, ∂Rnl, ∂Ylm)

# -------------------
# pair potential
if model.pairbasis != nothing
Rnl2, dRnl2 = @withalloc evaluate_ed_batched!(model.pairbasis,
rs, Z0, Zs,
ps.pairbasis, st.pairbasis)
B2_i = sum(Rnl2, dims=1)[:]
∂B2_i = zeros(eltype(∂Bmb_i), size(Rnl2, 2), size(Rnl2, 1))
for nl = 1:size(dRnl2, 2)
for j = 1:size(dRnl2, 1)
∂B2_i[nl, j] = dRnl2[j, nl] * ∇rs[j]
end
end
else
B2_i = zeros(eltype(Bmb_i), 0)
∂B2_i = zeros(eltype(∂Bmb_i), 0, size(Bmb_i, 2))
end

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
end # @no_escape

B[get_basis_inds(model, Z0)] .= Bmb_i
B[get_pairbasis_inds(model, Z0)] .= B2_i
∂B[get_basis_inds(model, Z0), :] .= ∂Bmb_i
∂B[get_pairbasis_inds(model, Z0), :] .= ∂B2_i

return B, ∂B
end
75 changes: 75 additions & 0 deletions src/models/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,78 @@ function get_nnll_spec(tensor::SparseEquivTensor{T}) where {T}
return nnll_list
end



# ----------------------------------------
# experimental pushforwards

function _pfwd(tensor::SparseEquivTensor{T}, Rnl, Ylm, ∂Rnl, ∂Ylm) where {T}
A, ∂A = _pfwd(tensor.abasis, (Rnl, Ylm), (∂Rnl, ∂Ylm))
_AA, _∂AA = _pfwd(tensor.aabasis, A, ∂A)

# project to the actual AA basis
proj = tensor.aabasis.projection
AA = _AA[proj]
∂AA = _∂AA[proj, :]

# evaluate the coupling coefficients
B = tensor.A2Bmap * AA
∂B = tensor.A2Bmap * ∂AA
return B, ∂B
end


function _pfwd(abasis::Polynomials4ML.PooledSparseProduct{2}, RY, ∂RY)
R, Y = RY
TA = typeof(R[1] * Y[1])
∂R, ∂Y = ∂RY
∂TA = typeof(R[1] * ∂Y[1] + ∂R[1] * Y[1])

# check lengths
nX = size(R, 1)
@assert nX == size(R, 1) == size(∂R, 1) == size(Y, 1) == size(∂Y, 1)

A = zeros(TA, length(abasis.spec))
∂A = zeros(∂TA, size(∂R, 1), length(abasis.spec))

for i = 1:length(abasis.spec)
@inbounds begin
n1, n2 = abasis.spec[i]
ai = zero(TA)
@simd ivdep for α = 1:nX
ai += R[α, n1] * Y[α, n2]
∂A[α, i] = R[α, n1] * ∂Y[α, n2] + ∂R[α, n1] * Y[α, n2]
end
A[i] = ai
end
end
return A, ∂A
end


function _pfwd(aabasis::Polynomials4ML.SparseSymmProdDAG, A, ∂A)
n∂ = size(∂A, 1)
num1 = aabasis.num1
nodes = aabasis.nodes
AA = zeros(eltype(A), length(nodes))
T∂AA = typeof(A[1] * ∂A[1])
∂AA = zeros(T∂AA, length(nodes), size(∂A, 1))
for i = 1:num1
AA[i] = A[i]
for α = 1:n∂
∂AA[i, α] = ∂A[α, i]
end
end
for iAA = num1+1:length(nodes)
n1, n2 = nodes[iAA]
AA_n1 = AA[n1]
AA_n2 = AA[n2]
AA[iAA] = AA_n1 * AA_n2
for α = 1:n∂
∂AA[iAA, α] = AA_n2 * ∂AA[n1, α] + AA_n1 * ∂AA[n2, α]
end
end
return AA, ∂AA
end


28 changes: 23 additions & 5 deletions test/models/test_ace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ for ybasis in [:spherical, :solid]
##

@info("Test second mixed derivatives reverse-over-reverse")
for ntest = 1:20
for ntest = 1:10
local Nat, Rs, Zs, Us, Ei, ∂Ei, ∂2_Ei,
ps_vec, vs_vec, F, dF0, z0, _restruct

Expand Down Expand Up @@ -150,7 +150,7 @@ for ybasis in [:spherical, :solid]

@info("Test basis implementation")

for ntest = 1:30
for ntest = 1:5
local Nat, Rs, Zs, z0, Ei, B, θ, st1 , ∇Ei

Nat = 15
Expand All @@ -162,7 +162,9 @@ for ybasis in [:spherical, :solid]
print_tf(@test Ei ≈ dot(B, θ))

Ei, ∇Ei = M.evaluate_ed(model, Rs, Zs, z0, ps, st)
B, ∇B = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st)

B1, ∇B = M.evaluate_basis_ed(model, Rs, Zs, z0, ps, st)
print_tf(@test B ≈ B1)
print_tf(@test ∇Ei ≈ sum(θ .* ∇B, dims=1)[:])
end
println()
Expand All @@ -171,7 +173,7 @@ for ybasis in [:spherical, :solid]

@info("Test the full mixed jacobian")

for ntest = 1:30
for ntest = 1:5
local Nat, Rs, Zs, z0, Ei, ∇Ei, ∂∂Ei, Us, F, dF0

Nat = 15
Expand All @@ -193,7 +195,7 @@ for ybasis in [:spherical, :solid]
ps_lin.WB[:] .= ps.WB[:]
ps_lin.Wpair[:] .= ps.Wpair[:]

for ntest = 1:10
for ntest = 1:5
local len, Nat, Rs, Zs, z0, Ei
len = 100
mae = sum(1:len) do _
Expand All @@ -208,6 +210,22 @@ for ybasis in [:spherical, :solid]
end
println()

##

@info("After splinification check correctness of evaluate_basis_ed again")
for ntest = 1:5
local Nat, Rs, Zs, z0, Ei
Nat = rand(8:16)
Rs, Zs, z0 = M.rand_atenv(model, Nat)
Us = randn(SVector{3, Float64}, Nat) / sqrt(Nat)
B, ∂B = M.evaluate_basis_ed(lin_ace, Rs, Zs, z0, ps, st)
B0 = M.evaluate_basis(lin_ace, Rs, Zs, z0, ps, st)
∂B0 = M.jacobian_grad_params(lin_ace, Rs, Zs, z0, ps, st)[3]
print_tf(@test B ≈ B0)
print_tf(@test ∂B ≈ ∂B0)
end
println()

end

##
Expand Down
1 change: 0 additions & 1 deletion test/test_silicon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ rmse_qr = Dict(
"set" => Dict("V"=>0.057, "E"=>0.0017, "F"=>0.12),
"bt" => Dict("V"=>0.08, "E"=>0.0022, "F"=>0.07),)


acefit!(data, model;
data_keys...,
weights = weights,
Expand Down
Loading