Skip to content

Commit

Permalink
specify map_location in mace_create_lammps_model
Browse files Browse the repository at this point in the history
  • Loading branch information
stenczelt committed Sep 18, 2024
1 parent 158b1f2 commit a2d3dd7
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def select_head(model):
def main():
args = parse_args()
model_path = args.model_path # takes model name as command-line input
model = torch.load(model_path)
model = torch.load(
model_path,
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
model = model.double().to("cpu")

if args.head is None:
Expand Down

0 comments on commit a2d3dd7

Please sign in to comment.