Skip to content

Commit

Permalink
feat: added the model type and version when checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed Oct 2, 2024
1 parent 9b9749e commit dacf533
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions hfgl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ class HiFiGANGenerator(pl.LightningModule):
for low-requirement model storage and inference.
"""

__version__ = "1"

def __init__(self, config: dict | VocoderConfig):
super().__init__()

Expand Down Expand Up @@ -481,6 +483,13 @@ def on_load_checkpoint(self, checkpoint):
Note, this shouldn't fail on different versions of pydantic anymore,
but it will fail on breaking changes to the config. We should catch those exceptions
and handle them appropriately."""
if "model_info" in checkpoint:
assert (
checkpoint["model_info"]["name"] == self.__class__.__name__
), f"""Wrong model type ({checkpoint["model_info"]["name"]}), we are expecting a { self.__class__.__name__ }"""
assert (
checkpoint["model_info"]["version"] == "1"
), f"""Wrong model's version({checkpoint["model_info"]["version"]}), we are expecting version {self.__version__}"""
try:
config = VocoderConfig(**checkpoint["hyper_parameters"]["config"])
except ValidationError as e:
Expand All @@ -500,6 +509,10 @@ def on_load_checkpoint(self, checkpoint):
def on_save_checkpoint(self, checkpoint):
"""Serialize the checkpoint hyperparameters"""
checkpoint["hyper_parameters"]["config"] = self.config.model_checkpoint_dump()
checkpoint["model_info"] = {
"name": self.__class__.__name__,
"version": self.__version__,
}


class HiFiGAN(HiFiGANGenerator):
Expand Down

0 comments on commit dacf533

Please sign in to comment.