diff --git a/teng_ml/rnn/training.py b/teng_ml/rnn/training.py index 298ab4f..b648462 100644 --- a/teng_ml/rnn/training.py +++ b/teng_ml/rnn/training.py @@ -68,10 +68,9 @@ def train(model, optimizer, scheduler, loss_func, train_loader: DataLoader, st: else: end='\n' print(f"Training:", epoch_tracker.get_epoch_summary_str(), end=end) # cancel training if model is not good enough - if len(training_cancel_points) > 0 and ep == training_cancel_points[0][0]: - print(f"Checking training cancel point: epoch={ep}, point={training_cancel_points[0]}, accuracy={epoch_tracker.accuracies[-1]}") + if len(training_cancel_points) > 0 and ep+1 == training_cancel_points[0][0]: if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]: - print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep} epochs.") + print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep+1} epochs.") break; training_cancel_points.pop(0)