From 2f114567944f044dbb17469ba8dd84b9328ff992 Mon Sep 17 00:00:00 2001 From: Christoph Ortner Date: Wed, 11 Sep 2024 13:08:38 -0700 Subject: [PATCH] some fixes to zuo example --- examples/zuobench/error_table.jl | 18 +++--- examples/zuobench/zuo_asp.jl | 95 ++++++++++++++++---------------- src/fit_model.jl | 51 +++++++++++------ 3 files changed, 89 insertions(+), 75 deletions(-) diff --git a/examples/zuobench/error_table.jl b/examples/zuobench/error_table.jl index 88a8786f..a556e4d0 100644 --- a/examples/zuobench/error_table.jl +++ b/examples/zuobench/error_table.jl @@ -1,13 +1,15 @@ +# This script reproduces the error table from the 2023/24 ACEpotentials +# paper, but with the new version 0.8 ACE models. + using Distributed addprocs(10, exeflags="--project=$(Base.active_project())") @everywhere using ACEpotentials, PrettyTables # the dataset is provided via ACE1pack artifacts as a convenient benchmarkset # the following chemical symbols are available: -# syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge] -syms = [:Ni, :Cu, ] +syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge] -totaldegree_tiny = [ 18, 14, 10 ] # very small model: ~ 100 basis functions +totaldegree_tiny = [ 16, 12, 8 ] # very small model: ~ 120 basis functions totaldegree_sm = [ 20, 16, 12 ] # small model: ~ 300 basis functions totaldegree_lge = [ 25, 21, 17 ] # large model: ~ 1000 basis functions @@ -19,7 +21,6 @@ err = Dict("lge" => Dict("E" => Dict(), "F" => Dict()), for sym in syms @info("---------- fitting $(sym) ----------") train, test, _ = ACEpotentials.example_dataset("Zuo20_$sym") - train = train[1:5:end] # specify the models model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm) @@ -27,8 +28,9 @@ for sym in syms @info("$sym models: length = $(length_basis(model_lge)), $(length_basis(model_sm))") # train the model - acefit!(train, model_sm; solver=ACEfit.BLR()); GC.gc() - acefit!(train, model_lge; solver=ACEfit.BLR()); GC.gc() + solver = ACEfit.BLR(; factorization = :svd) + acefit!(train, model_sm; solver=solver); GC.gc() + acefit!(train, model_lge; solver=solver); GC.gc() # compute and store errors for later visualisation err_sm = ACEpotentials.linear_errors(test, model_sm) @@ -64,5 +66,5 @@ pretty_table(f_table; header = header) ## -pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header) -pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header) \ No newline at end of file +# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header) +# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header) \ No newline at end of file diff --git a/examples/zuobench/zuo_asp.jl b/examples/zuobench/zuo_asp.jl index cf07e4e4..e07c9edf 100644 --- a/examples/zuobench/zuo_asp.jl +++ b/examples/zuobench/zuo_asp.jl @@ -1,3 +1,4 @@ + using Distributed addprocs(10, exeflags="--project=$(Base.active_project())") @everywhere using ACEpotentials, PrettyTables @@ -5,73 +6,69 @@ addprocs(10, exeflags="--project=$(Base.active_project())") # the dataset is provided via ACE1pack artifacts as a convenient benchmarkset # the following chemical symbols are available: # syms = [:Ni, :Cu, :Li, :Mo, :Si, :Ge] -syms = [:Ni, :Cu, ] -totaldegree_sm = [ 20, 16, 12 ] # small model: ~ 300 basis functions -totaldegree_lge = [ 25, 21, 17 ] # large model: ~ 1000 basis functions +# this element is quite interesting because it gives odd results +sym = :Ni # :Mo -## +# start with a large-ish model +totaldegree = [ 28, 23, 18 ] +model = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree) +P = algebraic_smoothness_prior(model; p = 4) -sym = :Si -model = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm) +@info("$sym model, basis length = $(length_basis(model))") +## +@info("---------- Assemble Training and Validation Systems ----------") +_train_data, test_data, _ = ACEpotentials.example_dataset("Zuo20_$sym") +Random.shuffle!(_train_data) +train_data = _train_data[1:5:end] +val_data = _train_data[2:5:end] +# datakeys = (:) +At, yt, Wt = ACEpotentials.assemble(train_data, model) +Av, yv, Wv = ACEpotentials.assemble(val_data, model) -## +@info("Compute ASP Path") +solver = ACEfit.ASP(; P = P, select = (:byerror, 1.0), tsvd = true ) +asp_result = ACEfit.solve(solver, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv) -err = Dict("lge" => Dict("E" => Dict(), "F" => Dict()), - "sm" => Dict("E" => Dict(), "F" => Dict()) ) +## - -for sym in syms - @info("---------- fitting $(sym) ----------") - train, test, _ = ACEpotentials.example_dataset("Zuo20_$sym") - train = train[1:5:end] +@info("Look at the best model") +asp_result["C"] +set_parameters!(model, asp_result["C"]) +ACEpotentials.linear_errors(test_data, model) - # specify the models - # model_lge = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_lge) - model_sm = ace1_model(elements = [sym,], order = 3, totaldegree = totaldegree_sm) - @info("$sym models: length = $(length_basis(model_lge)), $(length_basis(model_sm))") - # train the model - acefit!(train, model_sm; solver=ACEfit.BLR()); GC.gc() - # acefit!(train, model_lge; solver=ACEfit.BLR()); GC.gc() - # compute and store errors for later visualisation - err_sm = ACEpotentials.linear_errors(test, model_sm) - err_lge = ACEpotentials.linear_errors(test, model_lge) - err["sm" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000 - err["sm" ]["F"][sym] = err_sm["mae"]["set"]["F"] - # err["lge"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000 - # err["lge"]["F"][sym] = err_lge["mae"]["set"]["F"] - err["lge"]["E"][sym] = 0.0 - err["lge"]["F"][sym] = 0.0 -end +# 120 +# 300 +# 1000 ## -header = ([ "", "ACE(sm)", "ACE(lge)", "GAP", "MTP"]) -e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79] -f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05] +# header = ([ "", "ACE(sm)", "ACE(lge)", "GAP", "MTP"]) +# e_table_gap_mtp = [ 0.42 0.48; 0.46 0.41; 0.49 0.49; 2.24 2.83; 2.91 2.21; 2.06 1.79] +# f_table_gap_mtp = [ 0.02 0.01; 0.01 0.01; 0.01 0.01; 0.09 0.09; 0.07 0.06; 0.05 0.05] -e_table = hcat(string.(syms), - [round(err["sm"]["E"][sym], digits=3) for sym in syms], - [round(err["lge"]["E"][sym], digits=3) for sym in syms], - e_table_gap_mtp) +# e_table = hcat(string.(syms), +# [round(err["sm"]["E"][sym], digits=3) for sym in syms], +# [round(err["lge"]["E"][sym], digits=3) for sym in syms], +# e_table_gap_mtp) -f_table = hcat(string.(syms), - [round(err["sm"]["F"][sym], digits=3) for sym in syms], - [round(err["lge"]["F"][sym], digits=3) for sym in syms], - f_table_gap_mtp) +# f_table = hcat(string.(syms), +# [round(err["sm"]["F"][sym], digits=3) for sym in syms], +# [round(err["lge"]["F"][sym], digits=3) for sym in syms], +# f_table_gap_mtp) -println("Energy Error") -pretty_table(e_table; header = header) +# println("Energy Error") +# pretty_table(e_table; header = header) -println("Force Error") -pretty_table(f_table; header = header) +# println("Force Error") +# pretty_table(f_table; header = header) -## +# ## -pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header) -pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header) \ No newline at end of file +# pretty_table(e_table, backend = Val(:latex), label = "Energy MAE", header = header) +# pretty_table(f_table, backend = Val(:latex), label = "Forces MAE", header = header) \ No newline at end of file diff --git a/src/fit_model.jl b/src/fit_model.jl index 381fa688..34e5929d 100644 --- a/src/fit_model.jl +++ b/src/fit_model.jl @@ -32,6 +32,24 @@ end # ---------------- the main fitting function +function make_atoms_data(raw_data::AbstractArray{<: AbstractSystem}, model; + energy_key, force_key, virial_key, weights) + + # convert raw data to AtomsData, which ACEfit.jl understands + data = map( raw_data ) do d + AtomsData(d; + energy_key = energy_key, + force_key = force_key, + virial_key = virial_key, + weights = weights, + v_ref = _get_Vref(model) + ) + end + + return data +end + + """ acefit!(rawdata, model; kwargs...) @@ -72,16 +90,11 @@ function acefit!(raw_data::AbstractArray{<: AbstractSystem}, model; kwargs... ) - # convert raw data to AtomsData, which ACEfit.jl understands - data = map( raw_data ) do d - AtomsData(d; - energy_key = energy_key, - force_key=force_key, - virial_key = virial_key, - weights = weights, - v_ref = _get_Vref(model) - ) - end + data = make_atoms_data(raw_data, model; + energy_key = energy_key, + force_key = force_key, + virial_key = virial_key, + weights = weights) # print some information about the dataset # (how many observations in each class) @@ -170,10 +183,10 @@ end function assemble(raw_data::AbstractArray{<: AbstractSystem}, model; - weights = default_weights(), - energy_key = "energy", - force_key = "force", - virial_key = "virial", + weights = default_weights(), + energy_key = "energy", + force_key = "force", + virial_key = "virial", # smoothness = 4, # prior = nothing, repulsion_restraint = false, @@ -181,9 +194,11 @@ function assemble(raw_data::AbstractArray{<: AbstractSystem}, model; mode = :serial, weights_only = false) - data = [ AtomsData(at; energy_key = energy_key, force_key=force_key, - virial_key = virial_key, weights = weights, - v_ref = model.Vref) for at in raw_data ] + data = make_atoms_data(raw_data, model; + energy_key = energy_key, + force_key = force_key, + virial_key = virial_key, + weights = weights) if repulsion_restraint error("Repulsion restraint is currently not implemented") @@ -195,6 +210,6 @@ function assemble(raw_data::AbstractArray{<: AbstractSystem}, model; return W end - A, Y, W = assemble(data, model.basis, mode) + A, Y, W = assemble(data, model) return A, Y, W end