diff --git a/scripts/example_params.json b/scripts/example_params.json
index 0b322a52..3dc6b35a 100644
--- a/scripts/example_params.json
+++ b/scripts/example_params.json
@@ -63,7 +63,11 @@
       "virial_key": "dft_virial"
    },
    "output": {
-      "model": "results.json"
+      "model": "results.json",
+      "dimer": true,
+      "error_table": true,
+      "scatter": true,
+      "make_plots": false
    }
 
 }
\ No newline at end of file
diff --git a/scripts/runfit.jl b/scripts/runfit.jl
index 245947da..0651afde 100644
--- a/scripts/runfit.jl
+++ b/scripts/runfit.jl
@@ -18,6 +18,8 @@ parser = ArgParseSettings(description="Fit an ACE potential from parameters file
         help = "Number of processes for BLAS to use when solving the LsqDB"
         arg_type = Int
         default = 1
+    "--result_folder", "-o"
+        help = "folder path to store all results"
 end
 
 # parse the command
@@ -26,6 +28,17 @@ args = parse_args(parser)
 @info("Load parameter file")
 args_dict = JSON.parsefile(args["params"])
 
+# outputs
+if args["result_folder"] === nothing
+    # default to "<params file name without extension>_results"; the path is
+    # computed once so the log message and the folder actually used agree
+    args["result_folder"] = splitext(args["params"])[1] * "_results"
+    @info("result_folder not specified, create result folder at $(args["result_folder"])")
+end
+res_path = args["result_folder"]
+mkpath(res_path)
+@info("result storing at $(res_path)")
+
 @info("Construct ACEmodel of type $(args_dict["model"]["model_name"])")
 model = ACEpotentials.make_model(args_dict["model"])
 
@@ -48,24 +61,122 @@ acefit!(train, model;
         weights = weights,
         solver = solver)
 
-# training errors
-err_train = ACEpotentials.linear_errors(train, model; data_keys..., weights=weights)
-err = Dict("train" => err_train)
+# --- saving results and model below ---
+D = Dict()
+OD = args_dict["output"]
+# output switches; a key missing from the parameter file means "off", so
+# parameter files that only specify "model" still run
+do_error_table = get(OD, "error_table", false)
+do_scatter = get(OD, "scatter", false)
+do_dimer = get(OD, "dimer", false)
+do_plots = get(OD, "make_plots", false)
+# defined up front because `save_model` below receives it even when no
+# errors were evaluated
+err = Dict{String,Any}()
+
+# train/test errors
+if do_error_table || do_scatter
+    @info("evaluating errors")
+    # training errors
+    err_train, train_evf = ACEpotentials.linear_errors(train, model; data_keys..., weights=weights, return_efv = true)
+    err["train"] = err_train
+    if do_scatter
+        D["train_evf"] = train_evf
+    end
 
-# test errors (if a test dataset exists)
-if haskey(args_dict["data"], "test_file")
-    test = ExtXYZ.load(args_dict["data"]["test_file"])
-    err_test = ACEpotentials.linear_errors(test, model; data_keys..., weights=weights)
-    err["test"] = err_test
+    # test errors (if a test dataset exists)
+    if haskey(args_dict["data"], "test_file")
+        test = ExtXYZ.load(args_dict["data"]["test_file"])
+        err_test, test_evf = ACEpotentials.linear_errors(test, model; data_keys..., weights=weights, return_efv = true)
+        err["test"] = err_test
+        if do_scatter
+            D["test_evf"] = test_evf
+        end
+    end
 end
 
-# saving results
-result_file = args_dict["output"]["model"]
-ACEpotentials.save_model(model, joinpath(@__DIR__(), args_dict["output"]["model"]);
-      model_spec = args_dict,
-      errors = err, )
+# dimer analysis if specified
+if do_dimer
+    D["dimers"] = ACEpotentials.dimers(model, args_dict["model"]["elements"])
+end
 
+# saving results to folder
+ACEpotentials.save_model(model, joinpath(res_path, args_dict["output"]["model"]),
+      model_spec = args_dict,
+      errors = err,
+      meta = D)
 
 # To load the model, active the same Julia environment, then run
 # `model, meta = ACEpotentials.load_model("path/to/model.json")`
-# the resulting `model` object should be equivalent to the fitted `model`.
\ No newline at end of file
+# the resulting `model` object should be equivalent to the fitted `model`.
+
+if do_error_table
+    # the do-block forms close the file and restore stdout/stderr even when
+    # printing throws
+    open(joinpath(res_path, "error_table.txt"), "w") do et_file
+        redirect_stdio(stdout=et_file, stderr=et_file) do
+            @info("Training error")
+            ACEpotentials.print_errors_tables(err["train"])
+            # test errors exist only when a test dataset was supplied
+            if haskey(err, "test")
+                @info("Testing error")
+                ACEpotentials.print_errors_tables(err["test"])
+            end
+        end
+    end
+end
+
+# --- make plots if specified ---
+if do_plots
+    # 1. scatter EFV
+    if do_scatter
+        using Plots
+        function scatter_quantity(Xs, Ys; args...)
+            p = scatter(Xs, Ys, markersize=5; args...)
+            min_val = min(minimum(Xs), minimum(Ys))
+            max_val = max(maximum(Xs), maximum(Ys))
+            plot!([min_val, max_val], [min_val, max_val], linestyle=:dash, color=:red, linewidth=2.0)
+            return p
+        end
+        function concat_preallocate(vectors)
+            total_length = sum(length(v) for v in vectors)
+            result = zeros(Float64, total_length)
+            pos = 1
+            for v in vectors
+                result[pos:pos+length(v)-1] .= v
+                pos += length(v)
+            end
+            return result
+        end
+        # test data is optional, so plot only the datasets that were evaluated
+        sets = haskey(D, "test_evf") ? ["train", "test"] : ["train"]
+        PP = Dict(s => Dict() for s in sets)
+        for X in ("E", "F", "V"), s in sets
+            evf = D[s * "_evf"]
+            PP[s][X] = scatter_quantity(concat_preallocate(evf[X * "pred"]),
+                                        concat_preallocate(evf[X * "ref"]);
+                                        title = X * s, xlabel = "predicted", markerstrokewidth=0,
+                                        ylabel = "ground truth", legend = nothing,
+                                        alpha = 0.5
+                                        )
+        end
+        # one row per quantity (E, F, V), one column per dataset
+        pall = plot(
+            (PP[s][X] for X in ("E", "F", "V") for s in sets)...,
+            layout=(3, length(sets)), size=(800, 800),
+        )
+        savefig(pall, joinpath(res_path, "scatter.png"))
+    end
+
+    # 2. dimer analysis
+    if do_dimer
+        ZZ = args_dict["model"]["elements"]
+        for i in eachindex(ZZ), j in 1:i
+            Dij = D["dimers"][(ZZ[i], ZZ[j])]
+            # take > -10 eV portions
+            idx = ustrip.(Dij[2]) .> -10
+            p = plot(Dij[1][idx], -Dij[2][idx], xlabel = "r", ylabel = "E")
+            savefig(p, joinpath(res_path, "dimer_[$(ZZ[i]), $(ZZ[j])].png"))
+        end
+    end
+end
\ No newline at end of file
diff --git a/src/atoms_data.jl b/src/atoms_data.jl
index 3e511792..ad41a8a3 100644
--- a/src/atoms_data.jl
+++ b/src/atoms_data.jl
@@ -235,7 +235,9 @@ end
 
 
 function linear_errors(data::AbstractArray{AtomsData}, model;
-                       group_key="config_type", verbose=true)
+                       group_key="config_type", verbose=true,
+                       return_efv = false
+                       )
 
    mae = Dict("E"=>0.0, "F"=>0.0, "V"=>0.0)
    rmse = Dict("E"=>0.0, "F"=>0.0, "V"=>0.0)
@@ -246,6 +248,11 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
    config_rmse = Dict{String,Any}()
    config_num = Dict{String,Any}()
 
+   evf_dict = Dict("Epred" => [], "Eref" => [],
+                   "Fpred" => [], "Fref" => [],
+                   "Vpred" => [], "Vref" => [],
+                   )
+
    for d in data
 
        c_t = group_type(d; group_key)
@@ -268,6 +275,8 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
            config_mae[c_t]["E"] += abs(estim-exact)
            config_rmse[c_t]["E"] += (estim-exact)^2
            config_num[c_t]["E"] += 1
+           push!(evf_dict["Epred"], estim)
+           push!(evf_dict["Eref"], exact)
        end
 
        # force errors
@@ -280,6 +289,8 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
            config_mae[c_t]["F"] += sum(abs, estim - exact)
            config_rmse[c_t]["F"] += sum(abs2, estim - exact)
            config_num[c_t]["F"] += 3*length(d.system)
+           push!(evf_dict["Fpred"], estim)
+           push!(evf_dict["Fref"], exact)
        end
 
        # virial errors
@@ -297,6 +308,8 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
            config_mae[c_t]["V"] += sum(abs, estim-exact)
            config_rmse[c_t]["V"] += sum(abs2, estim-exact)
            config_num[c_t]["V"] += 6
+           push!(evf_dict["Vpred"], estim)
+           push!(evf_dict["Vref"], exact)
        end
    end
 
@@ -327,7 +340,11 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
       print_errors_tables(config_errors)
    end
 
-   return config_errors
+   if return_efv
+      return config_errors, evf_dict
+   else
+      return config_errors
+   end
 end
 
 
diff --git a/src/fit_model.jl b/src/fit_model.jl
index 7aae62c6..c6a5da00 100644
--- a/src/fit_model.jl
+++ b/src/fit_model.jl
@@ -189,12 +189,14 @@ function linear_errors(raw_data::AbstractArray{<: AbstractSystem}, model;
                        force_key = "force",
                        virial_key = "virial",
                        weights = default_weights(),
-                       verbose = true )
+                       verbose = true,
+                       return_efv = false
+                       )
    data = [ AtomsData(at; energy_key = energy_key, force_key=force_key,
                           virial_key = virial_key, weights = weights,
                           v_ref = _get_Vref(model))
             for at in raw_data ]
-   return linear_errors(data, model; verbose=verbose)
+   return linear_errors(data, model; verbose=verbose, return_efv = return_efv)
 end
 
 
diff --git a/test/test_json.jl b/test/test_json.jl
index ed810ffd..d7391269 100644
--- a/test/test_json.jl
+++ b/test/test_json.jl
@@ -29,7 +29,7 @@ end
 @info("Load the results")
 using JSON
 example_params = JSON.parsefile(joinpath(tmpproj, "example_params.json"))
-results = JSON.parsefile(joinpath(tmpproj, "results.json"))
+results = JSON.parsefile(joinpath(tmpproj, "example_params_results/results.json"))
 
 @info("Clean up temporary project")
 run(`rm -rf $tmpproj`)