diff --git a/aviary/predict.py b/aviary/predict.py index a7b55efe..0fed9fd4 100644 --- a/aviary/predict.py +++ b/aviary/predict.py @@ -157,7 +157,10 @@ def make_ensemble_predictions( @print_walltime(end_desc="predict_from_wandb_checkpoints") def predict_from_wandb_checkpoints( - runs: list[wandb.apis.public.Run], cache_dir: str, **kwargs: Any + runs: list[wandb.apis.public.Run], + checkpoint_filename: str = "checkpoint.pth", + cache_dir: str = "./checkpoint_cache", + **kwargs: Any, ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: """Download and cache checkpoints for an ensemble of models, then make predictions on some dataset. Finally print ensemble metrics and store @@ -167,6 +170,7 @@ def predict_from_wandb_checkpoints( runs (list[wandb.apis.public.Run]): List of WandB runs to download model checkpoints from which are then loaded into memory to generate predictions for the input_col in df. + checkpoint_filename (str): Name of the checkpoint file to download. cache_dir (str): Directory to cache downloaded checkpoints in. **kwargs: Additional keyword arguments to pass to make_ensemble_predictions(). @@ -194,7 +198,7 @@ def predict_from_wandb_checkpoints( out_dir = f"{cache_dir}/{run_path}" os.makedirs(out_dir, exist_ok=True) - checkpoint_path = f"{out_dir}/checkpoint.pth" + checkpoint_path = f"{out_dir}/{checkpoint_filename}" checkpoint_paths.append(checkpoint_path) print(f"{idx:>3}/{len(runs)}: {run.url}\n\t{checkpoint_path}\n") @@ -202,7 +206,7 @@ def predict_from_wandb_checkpoints( md_file.write(f"[{run.name}]({run.url})\n") if not os.path.isfile(checkpoint_path): - run.file("checkpoint.pth").download(root=out_dir) + run.file(f"{checkpoint_filename}").download(root=out_dir) if target_col in kwargs: df, ensemble_metrics = make_ensemble_predictions(checkpoint_paths, **kwargs)