From f9df18d294b13ab2a86db7aa106f889cbafcd241 Mon Sep 17 00:00:00 2001 From: Samuel Larkin Date: Wed, 2 Oct 2024 14:31:36 -0400 Subject: [PATCH] feat: added a friendlier error message when providing the wrong model type --- fs2/cli/synthesize.py | 11 ++++++++++- fs2/model.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/fs2/cli/synthesize.py b/fs2/cli/synthesize.py index 1edacc8..b41e695 100644 --- a/fs2/cli/synthesize.py +++ b/fs2/cli/synthesize.py @@ -416,7 +416,16 @@ def synthesize( # noqa: C901 # Load checkpoints print(f"Loading checkpoint from {model_path}", file=sys.stderr) - model: FastSpeech2 = FastSpeech2.load_from_checkpoint(model_path).to(device) # type: ignore + + from pydantic import ValidationError + + try: + model: FastSpeech2 = FastSpeech2.load_from_checkpoint(model_path).to(device) # type: ignore + except (TypeError, ValidationError) as e: + # TODO: should we print the following message only if there are more than X validation errors? + # e.error_count() > X + logger.error(f"Unable to load {model_path}: {e}") + sys.exit(1) model.eval() # get global step diff --git a/fs2/model.py b/fs2/model.py index 34b909c..d258782 100644 --- a/fs2/model.py +++ b/fs2/model.py @@ -43,7 +43,16 @@ def __init__( """ """ super().__init__() if not isinstance(config, FastSpeech2Config): - config = FastSpeech2Config(**config) + from pydantic import ValidationError + + try: + config = FastSpeech2Config(**config) + except ValidationError as e: + logger.error(f"{e}") + raise TypeError( + "Unable to load config. Possible causes: is it really a FastSpeech2Config? or the correct version?" + ) + if stats is not None and not isinstance(stats, Stats): stats = Stats(**stats) self.config = config