diff --git a/mace/cli/create_lammps_model.py b/mace/cli/create_lammps_model.py index 3f647906..2530937e 100644 --- a/mace/cli/create_lammps_model.py +++ b/mace/cli/create_lammps_model.py @@ -54,7 +54,10 @@ def select_head(model): def main(): args = parse_args() model_path = args.model_path # takes model name as command-line input - model = torch.load(model_path) + model = torch.load( + model_path, + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) model = model.double().to("cpu") if args.head is None: