Skip to content

Commit

Permalink
test load potential
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Ortner committed Sep 12, 2024
1 parent f9928ee commit 41e0774
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 47 deletions.
5 changes: 3 additions & 2 deletions docs/src/gettingstarted/saving-and-loading.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ Suppose the result of `runfit.jl` (or an analogous approach) is saved to
`path/result.json`. If the original or a compatible Julia environment is
activated, then
```julia
model = ACEpotentials.load_model("path/result.json")
model, meta = ACEpotentials.load_model("path/result.json")
```
will return a `model::ACEPotential` structure that should be equivalent
to the original fitted potential.
to the original fitted potential. The `meta::Dict` dictionary contains all
the remaining information saved in `result.json`.

### Recovering the Julia environment

Expand Down
7 changes: 6 additions & 1 deletion scripts/runfit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,10 @@ end
# saving results
result_file = args_dict["output"]["model"]
ACEpotentials.save_model(model, joinpath(@__DIR__(), args_dict["output"]["model"]);
make_model_args = args_dict,
model_spec = args_dict,
errors = err, )


# To load the model, active the same Julia environment, then run
# `model, meta = ACEpotentials.load_model("path/to/model.json")`
# the resulting `model` object should be equivalent to the fitted `model`.
29 changes: 18 additions & 11 deletions src/json_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ import ArgParse
using NamedTupleTools
import .ACE1compat
using ACEfit
using Optimisers: destructure

# === nt utilities ===
function create_namedtuple(dict)
return NamedTuple{Tuple(Symbol.(keys(dict)))}(values(dict))
end

function nested_namedtuple_to_dict(nt)
return Dict(k => isa(v, NamedTuple) ? nested_namedtuple_to_dict(v) : v for (k, v) in pairs(nt))
end
recursive_dict2nt(x) = x

recursive_dict2nt(D::Dict) = (;
[ Symbol(key) => recursive_dict2nt(D[key]) for key in keys(D)]... )

function _sanitize_arg(arg)
if isa(arg, Vector)
Expand Down Expand Up @@ -107,26 +106,26 @@ save model constructor, model parameters, and other information to a JSON file.
* `model` : the model to be saved
* `filename` : the name of the file to which the model will be saved
* `make_model_args` : the arguments used to construct the model; without this
* `model_spec` : the arguments used to construct the model; without this
the model cannot be reconstructed unless the original script is available
* `errors` : the fitting / test errors computed during the fitting
* `verbose` : print information about the saving process
"""
function save_model(model, filename;
make_model_args = nothing,
model_spec = nothing,
errors = nothing,
verbose = true,
meta = Dict(), )

D = Dict("model_parameters" => model.ps,
D = Dict("model_parameters" => destructure(model.ps)[1],
"meta" => meta)

if isnothing(make_model_args)
if isnothing(model_spec)
if verbose
@warn("Only model parameters are saved but no information to reconstruct the model.")
end
else
D["make_model_args"] = make_model_args
D["model_spec"] = model_spec
end

if !isnothing(errors)
Expand All @@ -141,3 +140,11 @@ function save_model(model, filename;
@info "Results saved to file: $filename"
end
end


function load_model(filename)
D = JSON.parsefile(filename)
model = make_model(D["model_spec"])
set_parameters!(model, D["model_parameters"])
return model, D
end
4 changes: 2 additions & 2 deletions src/models/calculators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ length_unit(::ACEPotential) = u"Å"
initialparameters(rng::AbstractRNG, V::ACEPotential) = initialparameters(rng, V.model)
initialstates(rng::AbstractRNG, V::ACEPotential) = initialstates(rng, V.model)

set_parameters!(V::ACEPotential, ps) = (V.ps = ps; V)
set_parameters!(V::ACEPotential, ps::NamedTuple) = (V.ps = ps; V)
set_states!(V::ACEPotential, st) = (V.st = st; V)
set_psst!(V::ACEPotential, ps, st) = (V.ps = ps; V.st = st; V)

splinify(V::ACEPotential) = splinify(V, V.ps)
splinify(V::ACEPotential, ps) = ACEPotential(splinify(V.model, ps), nothing, nothing)

function set_parameters!(V::ACEPotential, θ::AbstractVector{<: Number})
function set_parameters!(V::ACEPotential, θ::AbstractVector)
ps_vec, _restruct = destructure(V.ps)
ps = _restruct(θ)
return set_parameters!(V, ps)
Expand Down
52 changes: 21 additions & 31 deletions test/test_io.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,25 @@
using Test
using ACEpotentials
using LazyArtifacts

using Test, ACEpotentials, AtomsBuilder
using AtomsCalculators: potential_energy
using Polynomials4ML.Testing: print_tf

model = acemodel(elements = [:Si],
Eref = [:Si => -158.54496821],
rcut = 5.5,
order = 3,
totaldegree = 12)
data = read_extxyz(artifact"Si_tiny_dataset" * "/Si_tiny.xyz")
data_keys = [:energy_key => "dft_energy",
:force_key => "dft_force",
:virial_key => "dft_virial"]
weights = Dict("default" => Dict("E"=>30.0, "F"=>1.0, "V"=>1.0),
"liq" => Dict("E"=>10.0, "F"=>0.66, "V"=>0.25))
model_spec = Dict("model_name" => "ACE1",
"elements" => ["Ti", "Al"],
"rcut" => 5.5,
"order" => 3,
"totaldegree" => 8)
model = ACEpotentials.make_model(model_spec)
set_parameters!(model, randn(length_basis(model)))


acefit!(model, data;
data_keys...,
weights = weights,
solver = ACEfit.LSQR(
damp = 2e-2,
conlim = 1e12,
atol = 1e-7,
maxiter = 100000,
verbose = false
)
)
fname = tempname() * ".json"
pot = ACEpotential(model.potential.components)
@test_throws AssertionError save_potential(fname, model; meta="meta test")
save_potential(fname, model; meta=Dict("test"=>"meta test") )
npot = load_potential(fname; new_format=true)
@test ace_energy(pot, data[1]) ace_energy(npot, data[1])
ACEpotentials.save_model(model, fname; model_spec = model_spec)

model1, meta = ACEpotentials.load_model(fname)

for ntest = 1:10
sys = rattle!(bulk(:Al, cubic=true) * 2, 0.1)
sys = randz!(sys, [:Ti => 0.5, :Al => 0.5])
print_tf( @test potential_energy(sys, model) potential_energy(sys, model1) )
end

rm(fname)

0 comments on commit 41e0774

Please sign in to comment.