Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: add shell output file and simple analysis #254

Merged
merged 6 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion scripts/example_params.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@
"virial_key": "dft_virial"
},
"output": {
"model": "results.json"
"model": "results.json",
"dimer": true,
"error_table": true,
"scatter": true,
"make_plots": false
}
}

129 changes: 115 additions & 14 deletions scripts/runfit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ parser = ArgParseSettings(description="Fit an ACE potential from parameters file
help = "Number of processes for BLAS to use when solving the LsqDB"
arg_type = Int
default = 1
"--result_folder", "-o"
help = "folder path to store all results"
end

# parse the command
Expand All @@ -26,6 +28,15 @@ args = parse_args(parser)
@info("Load parameter file")
args_dict = JSON.parsefile(args["params"])

# outputs
if args["result_folder"] === nothing
@info("result_folder not specified, create result folder at $(args["params"][1:end-5] * "results")")
args["result_folder"] = args["params"][1:end-5] * "_results"
end
res_path = args["result_folder"]
mkpath(res_path)
@info("result storing at $(res_path)")

@info("Construct ACEmodel of type $(args_dict["model"]["model_name"])")
model = ACEpotentials.make_model(args_dict["model"])

Expand All @@ -48,24 +59,114 @@ acefit!(train, model;
weights = weights,
solver = solver)

# training errors
err_train = ACEpotentials.linear_errors(train, model; data_keys..., weights=weights)
err = Dict("train" => err_train)
# --- saving results and model below ---
D = Dict()
OD = args_dict["output"]

# train/test errors
if OD["error_table"] || OD["scatter"]
@info("evaluating errors")
# training errors
err_train, train_evf = ACEpotentials.linear_errors(train, model; data_keys..., weights=weights, return_efv = true)
err = Dict("train" => err_train)
if OD["scatter"]
D["train_evf"] = train_evf
end

# test errors (if a test dataset exists)
if haskey(args_dict["data"], "test_file")
test = ExtXYZ.load(args_dict["data"]["test_file"])
err_test = ACEpotentials.linear_errors(test, model; data_keys..., weights=weights)
err["test"] = err_test
# test errors (if a test dataset exists)
if haskey(args_dict["data"], "test_file")
test = ExtXYZ.load(args_dict["data"]["test_file"])
err_test, test_evf = ACEpotentials.linear_errors(test, model; data_keys..., weights=weights, return_efv = true)
err["test"] = err_test
if OD["scatter"]
D["test_evf"] = test_evf
end
end
end

# saving results
result_file = args_dict["output"]["model"]
ACEpotentials.save_model(model, joinpath(@__DIR__(), args_dict["output"]["model"]);
model_spec = args_dict,
errors = err, )
# dimer analysis if specified
if args_dict["output"]["dimer"]
D["dimers"] = ACEpotentials.dimers(model, args_dict["model"]["elements"])
end

# saving results to folder
ACEpotentials.save_model(model, joinpath(res_path, args_dict["output"]["model"]),
model_spec = args_dict,
errors = err,
meta = D)

# To load the model, active the same Julia environment, then run
# `model, meta = ACEpotentials.load_model("path/to/model.json")`
# the resulting `model` object should be equivalent to the fitted `model`.
# the resulting `model` object should be equivalent to the fitted `model`.

if args_dict["output"]["error_table"]
et_file = open(joinpath(res_path, "error_table.txt"), "w")
ori_stdout = stdout; ori_stderr = stderr
redirect_stdio(stdout=et_file, stderr = et_file)
@info("Training error")
print_errors_tables(err_train)
@info("Testing error")
print_errors_tables(err_test)
redirect_stdio(stdout=ori_stdout, stderr=ori_stderr);
close(et_file)
end

# --- make plots if specified ---
if args_dict["output"]["make_plots"]
# 1. scatter EFV
if args_dict["output"]["scatter"]
using Plots
function scatter_quantity(Xs, Ys; args...)
p = scatter(Xs, Ys, markersize=5; args...)
min_val = min(minimum(Xs), minimum(Ys))
max_val = max(maximum(Xs), maximum(Ys))
plot!([min_val, max_val], [min_val, max_val], linestyle=:dash, color=:red, linewidth=2.0)
return p
end
function concat_preallocate(vectors)
total_length = sum(length(v) for v in vectors)
result = zeros(Float64, total_length)
pos = 1
for v in vectors
result[pos:pos+length(v)-1] .= v
pos += length(v)
end
return result
end
PP = Dict("train" => Dict(), "test" => Dict())
for X in ("E", "F", "V")
PP["train"][X] = scatter_quantity(concat_preallocate(train_evf[X * "pred"]),
concat_preallocate(train_evf[X * "ref"]);
title = X * "train", xlabel = "predicted", markerstrokewidth=0,
ylabel = "ground truth", legend = nothing,
alpha = 0.5
)
print("Done train")
PP["test"][X] = scatter_quantity(concat_preallocate(test_evf[X * "pred"]),
concat_preallocate(test_evf[X * "ref"]);
title = X * "test", xlabel = "predicted", markerstrokewidth=0,
ylabel = "ground truth", legend = nothing,
alpha = 0.5
)
end
pall = plot(
PP["train"]["E"], PP["test"]["E"],
PP["train"]["F"], PP["test"]["F"],
PP["test"]["V"], PP["test"]["V"],
layout=(3, 2), size=(800, 800),
)
savefig(pall, joinpath(res_path, "scatter.png"))
end

# 2. dimer analysis
if args_dict["output"]["dimer"]
ZZ = args_dict["model"]["elements"]
for i = 1:length(args_dict["model"]["elements"]), j = 1:i
Dij = D["dimers"][(ZZ[i], ZZ[j])]
# take > -10 eV portions
idx = ustrip.(Dij[2]) .> -10
p = plot(Dij[1][idx], -Dij[2][idx], xlabel = "r", ylabel = "E")
savefig(p, joinpath(res_path, "dimer_[$(ZZ[i]), $(ZZ[j])].png"))
end
end
end
21 changes: 19 additions & 2 deletions src/atoms_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ end


function linear_errors(data::AbstractArray{AtomsData}, model;
group_key="config_type", verbose=true)
group_key="config_type", verbose=true,
return_efv = false
)

mae = Dict("E"=>0.0, "F"=>0.0, "V"=>0.0)
rmse = Dict("E"=>0.0, "F"=>0.0, "V"=>0.0)
Expand All @@ -246,6 +248,11 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
config_rmse = Dict{String,Any}()
config_num = Dict{String,Any}()

evf_dict = Dict("Epred" => [], "Eref" => [],
"Fpred" => [], "Fref" => [],
"Vpred" => [], "Vref" => [],
)

for d in data

c_t = group_type(d; group_key)
Expand All @@ -268,6 +275,8 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
config_mae[c_t]["E"] += abs(estim-exact)
config_rmse[c_t]["E"] += (estim-exact)^2
config_num[c_t]["E"] += 1
push!(evf_dict["Epred"], estim)
push!(evf_dict["Eref"], exact)
end

# force errors
Expand All @@ -280,6 +289,8 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
config_mae[c_t]["F"] += sum(abs, estim - exact)
config_rmse[c_t]["F"] += sum(abs2, estim - exact)
config_num[c_t]["F"] += 3*length(d.system)
push!(evf_dict["Fpred"], estim)
push!(evf_dict["Fref"], exact)
end

# virial errors
Expand All @@ -297,6 +308,8 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
config_mae[c_t]["V"] += sum(abs, estim-exact)
config_rmse[c_t]["V"] += sum(abs2, estim-exact)
config_num[c_t]["V"] += 6
push!(evf_dict["Vpred"], estim)
push!(evf_dict["Vref"], exact)
end
end

Expand Down Expand Up @@ -327,7 +340,11 @@ function linear_errors(data::AbstractArray{AtomsData}, model;
print_errors_tables(config_errors)
end

return config_errors
if return_efv
return config_errors, evf_dict
else
return config_errors
end
end


Expand Down
6 changes: 4 additions & 2 deletions src/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ function linear_errors(raw_data::AbstractArray{<: AbstractSystem}, model;
force_key = "force",
virial_key = "virial",
weights = default_weights(),
verbose = true )
verbose = true,
return_efv = false
)
data = [ AtomsData(at; energy_key = energy_key, force_key=force_key,
virial_key = virial_key, weights = weights,
v_ref = _get_Vref(model))
for at in raw_data ]
return linear_errors(data, model; verbose=verbose)
return linear_errors(data, model; verbose=verbose, return_efv = return_efv)
end


Expand Down
2 changes: 1 addition & 1 deletion test/test_json.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
@info("Load the results")
using JSON
example_params = JSON.parsefile(joinpath(tmpproj, "example_params.json"))
results = JSON.parsefile(joinpath(tmpproj, "results.json"))
results = JSON.parsefile(joinpath(tmpproj, "example_params_results/results.json"))

@info("Clean up temporary project")
run(`rm -rf $tmpproj`)
Expand Down
Loading