Skip to content

Commit

Permalink
Merge pull request #251 from ACEsuit/co/valset
Browse files Browse the repository at this point in the history
Option to provide validation set to acefit!
  • Loading branch information
cortner authored Sep 12, 2024
2 parents 0f9bf21 + 082d2b9 commit 99a8842
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 2 deletions.
91 changes: 91 additions & 0 deletions examples/zuobench/error_table_svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# This script reproduces the error table from the 2023/24 ACEpotentials
# paper, but with the new version 0.8 ACE models.

using Distributed, Random
addprocs(10, exeflags="--project=$(Base.active_project())")
@everywhere using ACEpotentials, PrettyTables

# the dataset is provided via ACE1pack artifacts as a convenient benchmarkset
# the following chemical symbols are available:
syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge]

totaldegree_tiny = [ 16, 12, 8 ] # very small model: ~ 120 basis functions
totaldegree_sm = [ 20, 16, 12 ] # small model: ~ 300 basis functions
totaldegree_lge = [ 25, 21, 17 ] # large model: ~ 1000 basis functions

##

err = Dict("lge_blr" => Dict("E" => Dict(), "F" => Dict()),
"sm_blr" => Dict("E" => Dict(), "F" => Dict()),
"lge_svd" => Dict("E" => Dict(), "F" => Dict()),
"sm_svd" => Dict("E" => Dict(), "F" => Dict()) )

for sym in syms
@info("---------- fitting $(sym) ----------")
train, test, _ = ACEpotentials.example_dataset("Zuo20_$sym")

# specify the models
model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm)
model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge)
@info("$sym models: length = $(length_basis(model_lge)), $(length_basis(model_sm))")

# train the model
solver = ACEfit.BLR(; factorization = :svd)
acefit!(train, model_sm; solver=solver); GC.gc()
acefit!(train, model_lge; solver=solver); GC.gc()

# compute and store errors for later visualisation
err_sm = ACEpotentials.linear_errors(test, model_sm)
err_lge = ACEpotentials.linear_errors(test, model_lge)
err["sm_blr" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000
err["sm_blr" ]["F"][sym] = err_sm["mae"]["set"]["F"]
err["lge_blr"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000
err["lge_blr"]["F"][sym] = err_lge["mae"]["set"]["F"]

# train models with validated svd
shuffle!(train); isplit = floor(Int, 0.8 * length(train))
train1 = train[1:isplit]; val1 = train[isplit+1:end]
model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm)
model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge)
solver = ACEfit.TruncatedSVD() # truncation will be determined from validation set
acefit!(train1, model_sm; validation_set = val1, solver=solver); GC.gc()
acefit!(train1, model_lge; validation_set = val1, solver=solver); GC.gc()
err_sm = ACEpotentials.linear_errors(test, model_sm)
err_lge = ACEpotentials.linear_errors(test, model_lge)
err["sm_svd" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000
err["sm_svd" ]["F"][sym] = err_sm["mae"]["set"]["F"]
err["lge_svd"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000
err["lge_svd"]["F"][sym] = err_lge["mae"]["set"]["F"]
end


##

header = ([ "", "ACE(sm, blr)", "ACE(sm, svd)", "ACE(lge, blr)", "ACE(lge, svd)", "GAP", "MTP"])
e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79]
f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05]

e_table = hcat(string.(syms),
[round(err["sm_blr"]["E"][sym], digits=3) for sym in syms],
[round(err["sm_svd"]["E"][sym], digits=3) for sym in syms],
[round(err["lge_blr"]["E"][sym], digits=3) for sym in syms],
[round(err["lge_svd"]["E"][sym], digits=3) for sym in syms],
e_table_gap_mtp)

f_table = hcat(string.(syms),
[round(err["sm_blr"]["F"][sym], digits=3) for sym in syms],
[round(err["sm_svd"]["F"][sym], digits=3) for sym in syms],
[round(err["lge_blr"]["F"][sym], digits=3) for sym in syms],
[round(err["lge_svd"]["F"][sym], digits=3) for sym in syms],
f_table_gap_mtp)

println("Energy Error")
pretty_table(e_table; header = header)

println("Force Error")
pretty_table(f_table; header = header)

##

# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header)
# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header)
22 changes: 20 additions & 2 deletions src/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ the label of the data to which the parameters will be fitted.
in a JSON format, which can be read from Julia or Python
"""
function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model;
validation_set = nothing,
solver = ACEfit.BLR(),
weights = default_weights(),
energy_key = "energy",
Expand Down Expand Up @@ -133,8 +134,25 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model;
# then solve the transformed problem
Ap = Diagonal(W) * (A / P)
Y = W .* Y
result = ACEfit.solve(solver, Ap, Y)
coeffs = P \ result["C"]

if isnothing(validation_set)
result = ACEfit.solve(solver, Ap, Y)

else
@info("assemble validation system")
val_data = make_atoms_data(validation_set, model;
energy_key = energy_key,
force_key = force_key,
virial_key = virial_key,
weights = weights)
Av, Yv, Wv = ACEfit.assemble(val_data, model)
Avp = Diagonal(Wv) * (Av / P)
Yv = Wv .* Yv

result = ACEfit.solve(solver, Ap, Y, Avp, Yv)
end

coeffs = P \ result["C"]

# dispatch setting of parameters
__set_params!(model, coeffs)
Expand Down

0 comments on commit 99a8842

Please sign in to comment.