From 030ab599bb6f70bba135640c1aea46bbdf5a1bec Mon Sep 17 00:00:00 2001
From: Samuel Larkin
Date: Thu, 12 Sep 2024 10:40:58 -0400
Subject: [PATCH] feat: more debug logging

---
 everyvoice/base_cli/callback.py | 43 +++++++++++++++++++++------------
 everyvoice/base_cli/helpers.py  |  5 ++++
 2 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/everyvoice/base_cli/callback.py b/everyvoice/base_cli/callback.py
index 39ed2a32..d54f3698 100644
--- a/everyvoice/base_cli/callback.py
+++ b/everyvoice/base_cli/callback.py
@@ -21,28 +21,41 @@ def on_save_checkpoint(
         pl_module: "pl.LightningModule",
         checkpoint: Dict[str, Any],
     ) -> None:
+        import json
+
+        logger.debug(f"current_epoch={trainer.current_epoch}")
+        logger.debug(f"global_step={trainer.global_step}")
+        logger.debug(json.dumps(checkpoint["loops"]))
         batch_progress = trainer.fit_loop.epoch_loop.val_loop.batch_progress
         batch_progress.reset()
+        if True:
+            fl = checkpoint["loops"]["fit_loop"]
+            assert (
+                fl["epoch_loop.batch_progress"]["total"]["completed"]
+                == fl["epoch_loop.batch_progress"]["total"]["processed"]
+            )
+            assert (
+                fl["epoch_loop.batch_progress"]["current"]["completed"]
+                == fl["epoch_loop.batch_progress"]["current"]["processed"]
+            )
         if False:
             # WARNING: This makes it worse. The curves are even more staggered.
             # Could it be the order that messes it up that bad?
             # Incorrect batch progress saved in checkpoint at every_n_train_steps
             # https://github.com/Lightning-AI/pytorch-lightning/issues/18060#issuecomment-2080180970
-            checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["total"][
-                "completed"
-            ] = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["total"][
-                "processed"
-            ]
-            checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][
-                "completed"
-            ] = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][
-                "processed"
-            ]
-            checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"][
-                "_batches_that_stepped"
-            ] = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"][
-                "processed"
-            ]
+            fl = checkpoint["loops"]["fit_loop"]
+            fl["epoch_loop.batch_progress"]["total"]["completed"] = fl[
+                "epoch_loop.batch_progress"
+            ]["total"]["processed"]
+            fl["epoch_loop.batch_progress"]["current"]["completed"] = fl[
+                "epoch_loop.batch_progress"
+            ]["current"]["processed"]
+            # NOTE: This cannot be good because _batches_that_stepped is the
+            # overall batch count whereas current processed is with respect to
+            # the current epoch.
+            fl["epoch_loop.state_dict"]["_batches_that_stepped"] = fl[
+                "epoch_loop.batch_progress"
+            ]["current"]["processed"]
 
 
 def log_debug(method_name: str, trainer: "pl.Trainer") -> None:
diff --git a/everyvoice/base_cli/helpers.py b/everyvoice/base_cli/helpers.py
index 68821a50..d5eca827 100644
--- a/everyvoice/base_cli/helpers.py
+++ b/everyvoice/base_cli/helpers.py
@@ -248,6 +248,8 @@ def train_base_command(
         every_n_epochs=config.training.ckpt_epochs,
         enable_version_counter=True,
     )
+    logger.debug(f"{last_ckpt_callback.state_dict()}")
+    logger.debug(f"{last_ckpt_callback.state_key}")
     # This callback will only save the top-k checkpoints
     # based on minimization of the monitored loss
     monitored_ckpt_callback = ModelCheckpoint(
@@ -373,6 +375,9 @@ def train_base_command(
         new_config_with_paths = model_obj.config.model_dump(mode="json")
         old_ckpt = torch.load(last_ckpt, map_location=torch.device("cpu"))
         old_ckpt["hyper_parameters"]["config"] = new_config_with_paths
+        logger.debug(f"epoch={old_ckpt['epoch']}")
+        logger.debug(f"global_step={old_ckpt['global_step']}")
+        logger.debug(json.dumps(old_ckpt["loops"], indent=2))
         # TODO: check if we need to do the same thing with stats and any thing else registered on the model
         with tempfile.NamedTemporaryFile() as tmp:
             torch.save(old_ckpt, tmp.name)
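
For context, a minimal standalone sketch (not part of the patch) that inspects the same fit-loop counters offline from a checkpoint already saved to disk. The "last.ckpt" path is a placeholder; the key layout is the one the patch reads and logs.

# Standalone sketch: inspect the batch-progress counters the callback above
# logs, from a saved checkpoint. "last.ckpt" is a placeholder path.
import json

import torch

ckpt = torch.load("last.ckpt", map_location="cpu")
print(f"epoch={ckpt['epoch']} global_step={ckpt['global_step']}")

batch_progress = ckpt["loops"]["fit_loop"]["epoch_loop.batch_progress"]
print(json.dumps(batch_progress, indent=2))

# Same invariant the patch asserts at save time: "completed" should equal
# "processed" in both the total and the current counters.
for scope in ("total", "current"):
    completed = batch_progress[scope]["completed"]
    processed = batch_progress[scope]["processed"]
    print(f"{scope}: completed={completed} processed={processed} ok={completed == processed}")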