Skip to content

Commit

Permalink
some fixes to zuo example
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 11, 2024
1 parent a12c051 commit 2f11456
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 75 deletions.
18 changes: 10 additions & 8 deletions examples/zuobench/error_table.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# This script reproduces the error table from the 2023/24 ACEpotentials
# paper, but with the new version 0.8 ACE models.

using Distributed
addprocs(10, exeflags="--project=$(Base.active_project())")
@everywhere using ACEpotentials, PrettyTables

# the dataset is provided via ACE1pack artifacts as a convenient benchmark set;
# the following chemical symbols are available:
# syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge]
syms = [:Ni, :Cu, ]
syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge]

totaldegree_tiny = [ 18, 14, 10 ] # very small model: ~ 100 basis functions
totaldegree_tiny = [ 16, 12, 8 ] # very small model: ~ 120 basis functions
totaldegree_sm = [ 20, 16, 12 ] # small model: ~ 300 basis functions
totaldegree_lge = [ 25, 21, 17 ] # large model: ~ 1000 basis functions

Expand All @@ -19,16 +21,16 @@ err = Dict("lge" => Dict("E" => Dict(), "F" => Dict()),
for sym in syms
@info("---------- fitting $(sym) ----------")
train, test, _ = ACEpotentials.example_dataset("Zuo20_$sym")
train = train[1:5:end]

# specify the models
model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm)
model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge)
@info("$sym models: length = $(length_basis(model_lge)), $(length_basis(model_sm))")

# train the model
acefit!(train, model_sm; solver=ACEfit.BLR()); GC.gc()
acefit!(train, model_lge; solver=ACEfit.BLR()); GC.gc()
solver = ACEfit.BLR(; factorization = :svd)
acefit!(train, model_sm; solver=solver); GC.gc()
acefit!(train, model_lge; solver=solver); GC.gc()

# compute and store errors for later visualisation
err_sm = ACEpotentials.linear_errors(test, model_sm)
Expand Down Expand Up @@ -64,5 +66,5 @@ pretty_table(f_table; header = header)

##

pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header)
pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header)
# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header)
# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header)
95 changes: 46 additions & 49 deletions examples/zuobench/zuo_asp.jl
Original file line number Diff line number Diff line change
@@ -1,77 +1,74 @@

using Distributed
addprocs(10, exeflags="--project=$(Base.active_project())")
@everywhere using ACEpotentials, PrettyTables

# the dataset is provided via ACE1pack artifacts as a convenient benchmark set;
# the following chemical symbols are available:
# syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge]
syms = [:Ni, :Cu, ]

totaldegree_sm = [ 20, 16, 12 ] # small model: ~ 300 basis functions
totaldegree_lge = [ 25, 21, 17 ] # large model: ~ 1000 basis functions
# this element is quite interesting because it gives odd results
sym = :Ni # :Mo

##
# start with a large-ish model
totaldegree = [ 28, 23, 18 ]
model = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree)
P = algebraic_smoothness_prior(model; p = 4)

sym = :Si
model = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm)
@info("$sym model, basis length = $(length_basis(model))")

##

@info("---------- Assemble Training and Validation Systems ----------")
_train_data, test_data, _ = ACEpotentials.example_dataset("Zuo20_$sym")
Random.shuffle!(_train_data)
train_data = _train_data[1:5:end]
val_data = _train_data[2:5:end]
# datakeys = (:)
At, yt, Wt = ACEpotentials.assemble(train_data, model)
Av, yv, Wv = ACEpotentials.assemble(val_data, model)

##
@info("Compute ASP Path")
solver = ACEfit.ASP(; P = P, select = (:byerror, 1.0), tsvd = true )
asp_result = ACEfit.solve(solver, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv)

err = Dict("lge" => Dict("E" => Dict(), "F" => Dict()),
"sm" => Dict("E" => Dict(), "F" => Dict()) )
##


for sym in syms
@info("---------- fitting $(sym) ----------")
train, test, _ = ACEpotentials.example_dataset("Zuo20_$sym")
train = train[1:5:end]
@info("Look at the best model")
asp_result["C"]
set_parameters!(model, asp_result["C"])
ACEpotentials.linear_errors(test_data, model)

# specify the models
# model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge)
model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm)
@info("$sym models: length = $(length_basis(model_lge)), $(length_basis(model_sm))")

# train the model
acefit!(train, model_sm; solver=ACEfit.BLR()); GC.gc()
# acefit!(train, model_lge; solver=ACEfit.BLR()); GC.gc()

# compute and store errors for later visualisation
err_sm = ACEpotentials.linear_errors(test, model_sm)
err_lge = ACEpotentials.linear_errors(test, model_lge)
err["sm" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000
err["sm" ]["F"][sym] = err_sm["mae"]["set"]["F"]
# err["lge"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000
# err["lge"]["F"][sym] = err_lge["mae"]["set"]["F"]
err["lge"]["E"][sym] = 0.0
err["lge"]["F"][sym] = 0.0
end

# 120
# 300
# 1000

##

header = ([ "", "ACE(sm)", "ACE(lge)", "GAP", "MTP"])
e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79]
f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05]
# header = ([ "", "ACE(sm)", "ACE(lge)", "GAP", "MTP"])
# e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79]
# f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05]

e_table = hcat(string.(syms),
[round(err["sm"]["E"][sym], digits=3) for sym in syms],
[round(err["lge"]["E"][sym], digits=3) for sym in syms],
e_table_gap_mtp)
# e_table = hcat(string.(syms),
# [round(err["sm"]["E"][sym], digits=3) for sym in syms],
# [round(err["lge"]["E"][sym], digits=3) for sym in syms],
# e_table_gap_mtp)

f_table = hcat(string.(syms),
[round(err["sm"]["F"][sym], digits=3) for sym in syms],
[round(err["lge"]["F"][sym], digits=3) for sym in syms],
f_table_gap_mtp)
# f_table = hcat(string.(syms),
# [round(err["sm"]["F"][sym], digits=3) for sym in syms],
# [round(err["lge"]["F"][sym], digits=3) for sym in syms],
# f_table_gap_mtp)

println("Energy Error")
pretty_table(e_table; header = header)
# println("Energy Error")
# pretty_table(e_table; header = header)

println("Force Error")
pretty_table(f_table; header = header)
# println("Force Error")
# pretty_table(f_table; header = header)

##
# ##

pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header)
pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header)
# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header)
# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header)
51 changes: 33 additions & 18 deletions src/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ end

# ---------------- the main fitting function

"""
    make_atoms_data(raw_data, model; energy_key, force_key, virial_key, weights)

Convert every system in `raw_data` into an `AtomsData` object, the data
format that ACEfit.jl understands.

All four keyword arguments are required (they have no defaults) and are
forwarded unchanged to the `AtomsData` constructor. The constructor's `v_ref`
keyword is set to `_get_Vref(model)`.

Returns the result of mapping this conversion over `raw_data`.
"""
function make_atoms_data(raw_data::AbstractArray{<: AbstractSystem}, model;
                         energy_key, force_key, virial_key, weights)

    # per-system conversion; `_get_Vref(model)` is evaluated once per system,
    # so it is never called when `raw_data` is empty
    to_atoms_data(sys) = AtomsData(sys;
                                   energy_key = energy_key,
                                   force_key = force_key,
                                   virial_key = virial_key,
                                   weights = weights,
                                   v_ref = _get_Vref(model))

    return map(to_atoms_data, raw_data)
end


"""
acefit!(rawdata, model; kwargs...)
Expand Down Expand Up @@ -72,16 +90,11 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model;
kwargs...
)

# convert raw data to AtomsData, which ACEfit.jl understands
data = map( raw_data ) do d
AtomsData(d;
energy_key = energy_key,
force_key=force_key,
virial_key = virial_key,
weights = weights,
v_ref = _get_Vref(model)
)
end
data = make_atoms_data(raw_data, model;
energy_key = energy_key,
force_key = force_key,
virial_key = virial_key,
weights = weights)

# print some information about the dataset
# (how many observations in each class)
Expand Down Expand Up @@ -170,20 +183,22 @@ end


function assemble(raw_data::AbstractArray{<: AbstractSystem}, model;
weights = default_weights(),
energy_key = "energy",
force_key = "force",
virial_key = "virial",
weights = default_weights(),
energy_key = "energy",
force_key = "force",
virial_key = "virial",
# smoothness = 4,
# prior = nothing,
repulsion_restraint = false,
restraint_weight = 0.01,
mode = :serial,
weights_only = false)

data = [ AtomsData(at; energy_key = energy_key, force_key=force_key,
virial_key = virial_key, weights = weights,
v_ref = model.Vref) for at in raw_data ]
data = make_atoms_data(raw_data, model;
energy_key = energy_key,
force_key = force_key,
virial_key = virial_key,
weights = weights)

if repulsion_restraint
error("Repulsion restraint is currently not implemented")
Expand All @@ -195,6 +210,6 @@ function assemble(raw_data::AbstractArray{<: AbstractSystem}, model;
return W
end

A, Y, W = assemble(data, model.basis, mode)
A, Y, W = assemble(data, model)
return A, Y, W
end

0 comments on commit 2f11456

Please sign in to comment.