Skip to content

Commit

Permalink
feat: removed staggering runs but is this the minimal code?
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed Sep 12, 2024
1 parent 030ab59 commit 1092bb7
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions everyvoice/base_cli/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,12 @@ def train_base_command(
# This callback will always save the last checkpoint
# regardless of its performance.
last_ckpt_callback = ModelCheckpoint(
save_top_k=1,
save_top_k=-1,
save_last=True,
every_n_train_steps=config.training.ckpt_steps,
every_n_epochs=config.training.ckpt_epochs,
enable_version_counter=True,
save_on_train_epoch_end=True,
)
logger.debug(f"{last_ckpt_callback.state_dict()}")
logger.debug(f"{last_ckpt_callback.state_key}")
Expand All @@ -273,7 +274,11 @@ def train_base_command(
devices=devices,
max_epochs=config.training.max_epochs,
max_steps=config.training.max_steps,
check_val_every_n_epoch=config.training.check_val_every_n_epoch,
# NOTE: If we don't check validation at the end of the training epoch, saving the last checkpoint is disabled.
# pytorch_lightning/callbacks/model_checkpoint.py : 431
# _should_save_on_train_epoch_end [ModelCheckpoint]
# check_val_every_n_epoch=config.training.check_val_every_n_epoch,
check_val_every_n_epoch=1,
val_check_interval=config.training.val_check_interval,
callbacks=[
monitored_ckpt_callback,
Expand Down

0 comments on commit 1092bb7

Please sign in to comment.