From 61321e3919d749a7ab0394b94fe9b3524593360e Mon Sep 17 00:00:00 2001 From: "matthias@arch" Date: Mon, 14 Aug 2023 18:43:07 +0200 Subject: [PATCH] fixed epoch for cancel pts --- teng_ml/rnn/training.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/teng_ml/rnn/training.py b/teng_ml/rnn/training.py index 298ab4f..b648462 100644 --- a/teng_ml/rnn/training.py +++ b/teng_ml/rnn/training.py @@ -68,10 +68,9 @@ def train(model, optimizer, scheduler, loss_func, train_loader: DataLoader, st: else: end='\n' print(f"Training:", epoch_tracker.get_epoch_summary_str(), end=end) # cancel training if model is not good enough - if len(training_cancel_points) > 0 and ep == training_cancel_points[0][0]: - print(f"Checking training cancel point: epoch={ep}, point={training_cancel_points[0]}, accuracy={epoch_tracker.accuracies[-1]}") + if len(training_cancel_points) > 0 and ep+1 == training_cancel_points[0][0]: if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]: - print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep} epochs.") + print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep+1} epochs.") break; training_cancel_points.pop(0)