diff --git a/hfgl/model.py b/hfgl/model.py index b73cfde..c646767 100644 --- a/hfgl/model.py +++ b/hfgl/model.py @@ -454,6 +454,8 @@ class HiFiGANGenerator(pl.LightningModule): for low-requirement model storage and inference. """ + __version__ = "1" + def __init__(self, config: dict | VocoderConfig): super().__init__() @@ -481,6 +483,13 @@ def on_load_checkpoint(self, checkpoint): Note, this shouldn't fail on different versions of pydantic anymore, but it will fail on breaking changes to the config. We should catch those exceptions and handle them appropriately.""" + if "model_info" in checkpoint: + assert ( + checkpoint["model_info"]["name"] == self.__class__.__name__ + ), f"""Wrong model type ({checkpoint["model_info"]["name"]}), we are expecting a { self.__class__.__name__ }""" + assert ( + checkpoint["model_info"]["version"] == "1" + ), f"""Wrong model's version({checkpoint["model_info"]["version"]}), we are expecting version {self.__version__}""" try: config = VocoderConfig(**checkpoint["hyper_parameters"]["config"]) except ValidationError as e: @@ -500,6 +509,10 @@ def on_load_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint): """Serialize the checkpoint hyperparameters""" checkpoint["hyper_parameters"]["config"] = self.config.model_checkpoint_dump() + checkpoint["model_info"] = { + "name": self.__class__.__name__, + "version": self.__version__, + } class HiFiGAN(HiFiGANGenerator):