diff --git a/src/evaluate/saving.py b/src/evaluate/saving.py index 4eea1a6a..19f4b8a2 100644 --- a/src/evaluate/saving.py +++ b/src/evaluate/saving.py @@ -4,12 +4,26 @@ import sys from datetime import datetime from pathlib import Path +import numpy as np from datasets.utils.filelock import FileLock from . import __version__ +class NpEncoder(json.JSONEncoder): + """Numpy aware JSON encoder.""" + + def default(self, o): + if isinstance(o, np.floating): + return float(o) + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.ndarray): + return o.tolist() + return super().default(o) + + def save(path_or_file, **data): """ Saves results to a JSON file. Also saves system information such as current time, current commit @@ -40,7 +54,7 @@ def save(path_or_file, **data): with FileLock(str(file_path) + ".lock"): with open(file_path, "w") as f: - json.dump(data, f) + json.dump(data, f, cls=NpEncoder) # cleanup lock file try: @@ -65,9 +79,13 @@ def _setup_path(path_or_file, current_time): def _git_commit_hash(): - res = subprocess.run("git rev-parse --is-inside-work-tree".split(), cwd="./", stdout=subprocess.PIPE) + res = subprocess.run( + "git rev-parse --is-inside-work-tree".split(), cwd="./", stdout=subprocess.PIPE + ) if res.stdout.decode().strip() == "true": - res = subprocess.run("git rev-parse HEAD".split(), cwd=os.getcwd(), stdout=subprocess.PIPE) + res = subprocess.run( + "git rev-parse HEAD".split(), cwd=os.getcwd(), stdout=subprocess.PIPE + ) return res.stdout.decode().strip() else: return None