feat: more debug logging
SamuelLarkin committed Sep 12, 2024
1 parent 61ef60d commit 030ab59
Showing 2 changed files with 33 additions and 15 deletions.
43 changes: 28 additions & 15 deletions everyvoice/base_cli/callback.py
@@ -21,28 +21,41 @@ def on_save_checkpoint(
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        import json

        logger.debug(f"current_epoch={trainer.current_epoch}")
        logger.debug(f"global_step={trainer.global_step}")
        logger.debug(json.dumps(checkpoint["loops"]))
        batch_progress = trainer.fit_loop.epoch_loop.val_loop.batch_progress
        batch_progress.reset()
        if True:
            fl = checkpoint["loops"]["fit_loop"]
            assert (
                fl["epoch_loop.batch_progress"]["total"]["completed"]
                == fl["epoch_loop.batch_progress"]["total"]["processed"]
            )
            assert (
                fl["epoch_loop.batch_progress"]["current"]["completed"]
                == fl["epoch_loop.batch_progress"]["current"]["processed"]
            )
        if False:
            # WARNING: This makes it worse. The curves are even more staggered.
            # Could it be the order that messes it up that bad?
            # Incorrect batch progress saved in checkpoint at every_n_train_steps
            # https://github.com/Lightning-AI/pytorch-lightning/issues/18060#issuecomment-2080180970
            checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["total"][
                "completed"
            ] = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["total"][
                "processed"
            ]
            checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][
                "completed"
            ] = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][
                "processed"
            ]
            checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"][
                "_batches_that_stepped"
            ] = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][
                "processed"
            ]
            fl = checkpoint["loops"]["fit_loop"]
            fl["epoch_loop.batch_progress"]["total"]["completed"] = fl[
                "epoch_loop.batch_progress"
            ]["total"]["processed"]
            fl["epoch_loop.batch_progress"]["current"]["completed"] = fl[
                "epoch_loop.batch_progress"
            ]["current"]["processed"]
            # NOTE: This cannot be good because _batches_that_stepped is the
            # overall batch count whereas current processed is with respect to
            # the current epoch.
            fl["epoch_loop.state_dict"]["_batches_that_stepped"] = fl[
                "epoch_loop.batch_progress"
            ]["current"]["processed"]


def log_debug(method_name: str, trainer: "pl.Trainer") -> None:
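Note: the hook above runs each time Lightning serializes a checkpoint. Below is a minimal, self-contained sketch (not the EveryVoice code) of how such a callback can be attached to a Trainer; the class name and the print-based output are assumptions for illustration, while the hook signature and the checkpoint["loops"] keys follow the diff above and the Lightning 2.x Callback API.

from typing import Any, Dict

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback


class CheckpointProgressDebugCallback(Callback):
    # Hypothetical callback: dumps the fit-loop batch-progress counters
    # whenever a checkpoint is written, mirroring the asserts in the commit.
    def on_save_checkpoint(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any],
    ) -> None:
        fl = checkpoint["loops"]["fit_loop"]
        total = fl["epoch_loop.batch_progress"]["total"]
        print("total completed:", total["completed"], "processed:", total["processed"])


# Usage (sketch): pass the callback to the Trainer that writes the checkpoints.
# trainer = pl.Trainer(max_epochs=1, callbacks=[CheckpointProgressDebugCallback()])
# trainer.fit(model)  # model: any LightningModule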
5 changes: 5 additions & 0 deletions everyvoice/base_cli/helpers.py
@@ -248,6 +248,8 @@ def train_base_command(
        every_n_epochs=config.training.ckpt_epochs,
        enable_version_counter=True,
    )
    logger.debug(f"{last_ckpt_callback.state_dict()}")
    logger.debug(f"{last_ckpt_callback.state_key}")
    # This callback will only save the top-k checkpoints
    # based on minimization of the monitored loss
    monitored_ckpt_callback = ModelCheckpoint(
@@ -373,6 +375,9 @@ def train_base_command(
    new_config_with_paths = model_obj.config.model_dump(mode="json")
    old_ckpt = torch.load(last_ckpt, map_location=torch.device("cpu"))
    old_ckpt["hyper_parameters"]["config"] = new_config_with_paths
    logger.debug(f"epoch={old_ckpt['epoch']}")
    logger.debug(f"global_step={old_ckpt['global_step']}")
    logger.debug(json.dumps(old_ckpt["loops"], indent=2))
    # TODO: check if we need to do the same thing with stats and anything else registered on the model
    with tempfile.NamedTemporaryFile() as tmp:
        torch.save(old_ckpt, tmp.name)
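Note: the new logger.debug calls print the checkpoint's epoch, global_step, and loop state while the config is being rewritten. A minimal standalone sketch for inspecting the same fields from a checkpoint file on disk is below; the path is a placeholder, and depending on the torch version weights_only=False may be needed to unpickle a Lightning checkpoint.

import json

import torch

ckpt_path = "logs/last.ckpt"  # placeholder: point this at a real checkpoint
ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))  # add weights_only=False on torch >= 2.6

print("epoch:", ckpt["epoch"])
print("global_step:", ckpt["global_step"])
print(json.dumps(ckpt["loops"], indent=2))  # same structure the callback asserts on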
