diff --git a/examples/zuobench/README.md b/examples/zuobench/README.md new file mode 100644 index 00000000..e69de29b diff --git a/examples/zuobench/error_table.jl b/examples/zuobench/error_table.jl new file mode 100644 index 00000000..a556e4d0 --- /dev/null +++ b/examples/zuobench/error_table.jl @@ -0,0 +1,70 @@ +# This script reproduces the error table from the 2023/24 ACEpotentials +# paper, but with the new version 0.8 ACE models. + +using Distributed +addprocs(10, exeflags="--project=$(Base.active_project())") +@everywhere using ACEpotentials, PrettyTables + +# the dataset is provided via ACE1pack artifacts as a convenient benchmarkset +# the following chemical symbols are available: +syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge] + +totaldegree_tiny = [ 16, 12, 8 ] # very small model: ~ 120 basis functions +totaldegree_sm = [ 20, 16, 12 ] # small model: ~ 300 basis functions +totaldegree_lge = [ 25, 21, 17 ] # large model: ~ 1000 basis functions + +## + +err = Dict("lge" => Dict("E" => Dict(), "F" => Dict()), + "sm" => Dict("E" => Dict(), "F" => Dict()) ) + +for sym in syms + @info("---------- fitting $(sym) ----------") + train, test, _ = ACEpotentials.example_dataset("Zuo20_$sym") + + # specify the models + model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm) + model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge) + @info("$sym models: length = $(length_basis(model_lge)), $(length_basis(model_sm))") + + # train the model + solver = ACEfit.BLR(; factorization = :svd) + acefit!(train, model_sm; solver=solver); GC.gc() + acefit!(train, model_lge; solver=solver); GC.gc() + + # compute and store errors for later visualisation + err_sm = ACEpotentials.linear_errors(test, model_sm) + err_lge = ACEpotentials.linear_errors(test, model_lge) + err["sm" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000 + err["sm" ]["F"][sym] = err_sm["mae"]["set"]["F"] + err["lge"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000 + err["lge"]["F"][sym] = err_lge["mae"]["set"]["F"] +end + + +## + +header = ([ "", "ACE(sm)", "ACE(lge)", "GAP", "MTP"]) +e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79] +f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05] + +e_table = hcat(string.(syms), + [round(err["sm"]["E"][sym], digits=3) for sym in syms], + [round(err["lge"]["E"][sym], digits=3) for sym in syms], + e_table_gap_mtp) + +f_table = hcat(string.(syms), + [round(err["sm"]["F"][sym], digits=3) for sym in syms], + [round(err["lge"]["F"][sym], digits=3) for sym in syms], + f_table_gap_mtp) + +println("Energy Error") +pretty_table(e_table; header = header) + +println("Force Error") +pretty_table(f_table; header = header) + +## + +# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header) +# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header) \ No newline at end of file diff --git a/examples/zuobench/zuo_asp.jl b/examples/zuobench/zuo_asp.jl new file mode 100644 index 00000000..e07c9edf --- /dev/null +++ b/examples/zuobench/zuo_asp.jl @@ -0,0 +1,74 @@ + +using Distributed +addprocs(10, exeflags="--project=$(Base.active_project())") +@everywhere using ACEpotentials, PrettyTables + +# the dataset is provided via ACE1pack artifacts as a convenient benchmarkset +# the following chemical symbols are available: +# syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge] + +# this element is quite interesting because it gives odd results +sym = :Ni # :Mo + +# start with a large-ish model +totaldegree = [ 28, 23, 18 ] +model = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree) +P = algebraic_smoothness_prior(model; p = 4) + +@info("$sym model, basis length = $(length_basis(model))") + +## + +@info("---------- Assemble Training and Validation Systems ----------") +_train_data, test_data, _ = ACEpotentials.example_dataset("Zuo20_$sym") +Random.shuffle!(_train_data) +train_data = _train_data[1:5:end] +val_data = _train_data[2:5:end] +# datakeys = (:) +At, yt, Wt = ACEpotentials.assemble(train_data, model) +Av, yv, Wv = ACEpotentials.assemble(val_data, model) + +@info("Compute ASP Path") +solver = ACEfit.ASP(; P = P, select = (:byerror, 1.0), tsvd = true ) +asp_result = ACEfit.solve(solver, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv) + +## + +@info("Look at the best model") +asp_result["C"] +set_parameters!(model, asp_result["C"]) +ACEpotentials.linear_errors(test_data, model) + + + + +# 120 +# 300 +# 1000 + +## + +# header = ([ "", "ACE(sm)", "ACE(lge)", "GAP", "MTP"]) +# e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79] +# f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05] + +# e_table = hcat(string.(syms), +# [round(err["sm"]["E"][sym], digits=3) for sym in syms], +# [round(err["lge"]["E"][sym], digits=3) for sym in syms], +# e_table_gap_mtp) + +# f_table = hcat(string.(syms), +# [round(err["sm"]["F"][sym], digits=3) for sym in syms], +# [round(err["lge"]["F"][sym], digits=3) for sym in syms], +# f_table_gap_mtp) + +# println("Energy Error") +# pretty_table(e_table; header = header) + +# println("Force Error") +# pretty_table(f_table; header = header) + +# ## + +# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header) +# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header) \ No newline at end of file diff --git a/src/fit_model.jl b/src/fit_model.jl index 381fa688..34e5929d 100644 --- a/src/fit_model.jl +++ b/src/fit_model.jl @@ -32,6 +32,24 @@ end # ---------------- the main fitting function +function make_atoms_data(raw_data::AbstractArray{<: AbstractSystem}, model; + energy_key, force_key, virial_key, weights) + + # convert raw data to AtomsData, which ACEfit.jl understands + data = map( raw_data ) do d + AtomsData(d; + energy_key = energy_key, + force_key = force_key, + virial_key = virial_key, + weights = weights, + v_ref = _get_Vref(model) + ) + end + + return data +end + + """ acefit!(rawdata, model; kwargs...) @@ -72,16 +90,11 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model; kwargs... ) - # convert raw data to AtomsData, which ACEfit.jl understands - data = map( raw_data ) do d - AtomsData(d; - energy_key = energy_key, - force_key=force_key, - virial_key = virial_key, - weights = weights, - v_ref = _get_Vref(model) - ) - end + data = make_atoms_data(raw_data, model; + energy_key = energy_key, + force_key = force_key, + virial_key = virial_key, + weights = weights) # print some information about the dataset # (how many observations in each class) @@ -170,10 +183,10 @@ end function assemble(raw_data::AbstractArray{<: AbstractSystem}, model; - weights = default_weights(), - energy_key = "energy", - force_key = "force", - virial_key = "virial", + weights = default_weights(), + energy_key = "energy", + force_key = "force", + virial_key = "virial", # smoothness = 4, # prior = nothing, repulsion_restraint = false, @@ -181,9 +194,11 @@ function assemble(raw_data::AbstractArray{<: AbstractSystem}, model; mode = :serial, weights_only = false) - data = [ AtomsData(at; energy_key = energy_key, force_key=force_key, - virial_key = virial_key, weights = weights, - v_ref = model.Vref) for at in raw_data ] + data = make_atoms_data(raw_data, model; + energy_key = energy_key, + force_key = force_key, + virial_key = virial_key, + weights = weights) if repulsion_restraint error("Repulsion restraint is currently not implemented") @@ -195,6 +210,6 @@ function assemble(raw_data::AbstractArray{<: AbstractSystem}, model; return W end - A, Y, W = assemble(data, model.basis, mode) + A, Y, W = assemble(data, model) return A, Y, W end