From 082d2b92383bf48ad3b9afb65713781f813309a3 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Thu, 12 Sep 2024 09:42:09 -0700 Subject: [PATCH] add option to provide validation set --- examples/zuobench/error_table_svd.jl | 91 ++++++++++++++++++++++++++++ src/fit_model.jl | 22 ++++++- 2 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 examples/zuobench/error_table_svd.jl diff --git a/examples/zuobench/error_table_svd.jl b/examples/zuobench/error_table_svd.jl new file mode 100644 index 00000000..d2922128 --- /dev/null +++ b/examples/zuobench/error_table_svd.jl @@ -0,0 +1,91 @@ +# This script reproduces the error table from the 2023/24 ACEpotentials +# paper, but with the new version 0.8 ACE models. + +using Distributed, Random +addprocs(10, exeflags="--project=$(Base.active_project())") +@everywhere using ACEpotentials, PrettyTables + +# the dataset is provided via ACE1pack artifacts as a convenient benchmarkset +# the following chemical symbols are available: +syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge] + +totaldegree_tiny = [ 16, 12, 8 ] # very small model: ~ 120 basis functions +totaldegree_sm = [ 20, 16, 12 ] # small model: ~ 300 basis functions +totaldegree_lge = [ 25, 21, 17 ] # large model: ~ 1000 basis functions + +## + +err = Dict("lge_blr" => Dict("E" => Dict(), "F" => Dict()), + "sm_blr" => Dict("E" => Dict(), "F" => Dict()), + "lge_svd" => Dict("E" => Dict(), "F" => Dict()), + "sm_svd" => Dict("E" => Dict(), "F" => Dict()) ) + +for sym in syms + @info("---------- fitting $(sym) ----------") + train, test, _ = ACEpotentials.example_dataset("Zuo20_$sym") + + # specify the models + model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm) + model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge) + @info("$sym models: length = $(length_basis(model_lge)), $(length_basis(model_sm))") + + # train the model + solver = ACEfit.BLR(; factorization = :svd) + acefit!(train, model_sm; solver=solver); GC.gc() + acefit!(train, model_lge; solver=solver); GC.gc() + + # compute and store errors for later visualisation + err_sm = ACEpotentials.linear_errors(test, model_sm) + err_lge = ACEpotentials.linear_errors(test, model_lge) + err["sm_blr" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000 + err["sm_blr" ]["F"][sym] = err_sm["mae"]["set"]["F"] + err["lge_blr"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000 + err["lge_blr"]["F"][sym] = err_lge["mae"]["set"]["F"] + + # train models with validated svd + shuffle!(train); isplit = floor(Int, 0.8 * length(train)) + train1 = train[1:isplit]; val1 = train[isplit+1:end] + model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm) + model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge) + solver = ACEfit.TruncatedSVD() # truncation will be determined from validation set + acefit!(train1, model_sm; validation_set = val1, solver=solver); GC.gc() + acefit!(train1, model_lge; validation_set = val1, solver=solver); GC.gc() + err_sm = ACEpotentials.linear_errors(test, model_sm) + err_lge = ACEpotentials.linear_errors(test, model_lge) + err["sm_svd" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000 + err["sm_svd" ]["F"][sym] = err_sm["mae"]["set"]["F"] + err["lge_svd"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000 + err["lge_svd"]["F"][sym] = err_lge["mae"]["set"]["F"] +end + + +## + +header = ([ "", "ACE(sm, blr)", "ACE(sm, svd)", "ACE(lge, blr)", "ACE(lge, svd)", "GAP", "MTP"]) +e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79] +f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05] + +e_table = hcat(string.(syms), + [round(err["sm_blr"]["E"][sym], digits=3) for sym in syms], + [round(err["sm_svd"]["E"][sym], digits=3) for sym in syms], + [round(err["lge_blr"]["E"][sym], digits=3) for sym in syms], + [round(err["lge_svd"]["E"][sym], digits=3) for sym in syms], + e_table_gap_mtp) + +f_table = hcat(string.(syms), + [round(err["sm_blr"]["F"][sym], digits=3) for sym in syms], + [round(err["sm_svd"]["F"][sym], digits=3) for sym in syms], + [round(err["lge_blr"]["F"][sym], digits=3) for sym in syms], + [round(err["lge_svd"]["F"][sym], digits=3) for sym in syms], + f_table_gap_mtp) + +println("Energy Error") +pretty_table(e_table; header = header) + +println("Force Error") +pretty_table(f_table; header = header) + +## + +# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header) +# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header) \ No newline at end of file diff --git a/src/fit_model.jl b/src/fit_model.jl index 34e5929d..7aae62c6 100644 --- a/src/fit_model.jl +++ b/src/fit_model.jl @@ -75,6 +75,7 @@ the label of the data to which the parameters will be fitted. in a JSON format, which can be read from Julia or Python """ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model; + validation_set = nothing, solver = ACEfit.BLR(), weights = default_weights(), energy_key = "energy", @@ -133,8 +134,25 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model; # then solve the transformed problem Ap = Diagonal(W) * (A / P) Y = W .* Y - result = ACEfit.solve(solver, Ap, Y) - coeffs = P \ result["C"] + + if isnothing(validation_set) + result = ACEfit.solve(solver, Ap, Y) + + else + @info("assemble validation system") + val_data = make_atoms_data(validation_set, model; + energy_key = energy_key, + force_key = force_key, + virial_key = virial_key, + weights = weights) + Av, Yv, Wv = ACEfit.assemble(val_data, model) + Avp = Diagonal(Wv) * (Av / P) + Yv = Wv .* Yv + + result = ACEfit.solve(solver, Ap, Y, Avp, Yv) + end + + coeffs = P \ result["C"] # dispatch setting of parameters __set_params!(model, coeffs)