diff --git a/hfgl/model.py b/hfgl/model.py index fce5845..846bf9d 100644 --- a/hfgl/model.py +++ b/hfgl/model.py @@ -11,6 +11,8 @@ dynamic_range_compression_torch, get_spectral_transform, ) +from loguru import logger +from pydantic import ValidationError from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d from torch.nn.utils import spectral_norm from torch.nn.utils.parametrizations import weight_norm @@ -454,8 +456,16 @@ class HiFiGANGenerator(pl.LightningModule): def __init__(self, config: dict | VocoderConfig): super().__init__() + if not isinstance(config, VocoderConfig): - config = VocoderConfig(**config) + try: + config = VocoderConfig(**config) + except ValidationError as e: + logger.error(f"{e}") + raise TypeError( + "Unable to load config. Possible causes: is it really a VocoderConfig? or the correct version?" + ) + self.config = config self.generator = Generator(self.config) @@ -464,7 +474,14 @@ def on_load_checkpoint(self, checkpoint): Note, this shouldn't fail on different versions of pydantic anymore, but it will fail on breaking changes to the config. We should catch those exceptions and handle them appropriately.""" - self.config = VocoderConfig(**checkpoint["hyper_parameters"]["config"]) + try: + config = VocoderConfig(**checkpoint["hyper_parameters"]["config"]) + except ValidationError as e: + logger.error(f"{e}") + raise TypeError( + "Unable to load config. Possible causes: is it really a VocoderConfig? or the correct version?" + ) + self.config = config def on_save_checkpoint(self, checkpoint): """Serialize the checkpoint hyperparameters""" @@ -479,7 +496,13 @@ def __init__(self, config: dict | VocoderConfig): # Because we serialize the configurations when saving checkpoints, # sometimes what is passed is actually just a dict. if not isinstance(config, VocoderConfig): - config = VocoderConfig(**config) + try: + config = VocoderConfig(**config) + except ValidationError as e: + logger.error(f"{e}") + raise TypeError( + "Unable to load config. Possible causes: is it really a VocoderConfig? or the correct version?" + ) self.config = config self.mpd = MultiPeriodDiscriminator(config) self.msd = MultiScaleDiscriminator(config)