Skip to content

Commit

Permalink
fix(finetune): split training into three parts
Browse files Browse the repository at this point in the history
1) training from scratch; 2) resuming from a checkpoint without changes (preserves the epoch and current step); and 3) fine-tuning by changing values in the training configuration
  • Loading branch information
roedoejet committed Sep 19, 2023
1 parent 65e4b01 commit cca7eae
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
31 changes: 29 additions & 2 deletions everyvoice/base_cli/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import os
from enum import Enum
from pathlib import Path
from pprint import pformat
from typing import List, Optional, Union

from deepdiff import DeepDiff
from loguru import logger
from tqdm import tqdm

Expand Down Expand Up @@ -147,8 +149,33 @@ def train_base_command(
and os.path.exists(config.training.finetune_checkpoint)
else None
)
tensorboard_logger.log_hyperparams(config.dict())
trainer.fit(model_obj, data, ckpt_path=last_ckpt)
# Train from Scratch
if last_ckpt is None:
model_obj = model(config)
tensorboard_logger.log_hyperparams(config.dict())
trainer.fit(model_obj, data)
else:
model_obj = model.load_from_checkpoint(last_ckpt)
# Check if the trainer has changed (but ignore subdir since it is specific to the run)
diff = DeepDiff(model_obj.config.training.dict(), config.training.dict())
training_config_diff = [
item for item in diff["values_changed"].items() if "sub_dir" not in item[0]
]
if training_config_diff:
model_obj.config.training = config.training
tensorboard_logger.log_hyperparams(config.dict())
# Finetune from Checkpoint
logger.warning(
f"""Some of your training hyperparameters have changed from your checkpoint at '{last_ckpt}', so we will override your checkpoint hyperparameters.
Your training logs will start from epoch 0/step 0, but will still use the weights from your checkpoint. Values Changed: {pformat(training_config_diff)}
"""
)
trainer.fit(model_obj, data)
else:
logger.info(f"Resuming from checkpoint '{last_ckpt}'")
# Resume from checkpoint
tensorboard_logger.log_hyperparams(config.dict())
trainer.fit(model_obj, data, ckpt_path=last_ckpt)


def inference_base_command(name: Enum):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
clipdetect>=0.1.3
deepdiff>=6.5.0
anytree>=2.8.0
einops==0.5.0
g2p>=1.0.20230417
Expand Down

0 comments on commit cca7eae

Please sign in to comment.