diff --git a/examples/train_resnet_base.py b/examples/train_resnet_base.py index daab0b39ae5..c4a8890e9be 100644 --- a/examples/train_resnet_base.py +++ b/examples/train_resnet_base.py @@ -48,6 +48,7 @@ def step_fn(self, data, target): loss = self.loss_fn(output, target) loss.backward() self.run_optimizer() + return loss def train_loop_fn(self, loader, epoch): tracker = xm.RateTracker()