fixed epoch for cancel pts

This commit is contained in:
matthias@arch 2023-08-14 18:43:07 +02:00
parent 1d05da3abf
commit 61321e3919

View File

@ -68,10 +68,9 @@ def train(model, optimizer, scheduler, loss_func, train_loader: DataLoader, st:
else: end='\n' else: end='\n'
print(f"Training:", epoch_tracker.get_epoch_summary_str(), end=end) print(f"Training:", epoch_tracker.get_epoch_summary_str(), end=end)
# cancel training if model is not good enough # cancel training if model is not good enough
if len(training_cancel_points) > 0 and ep == training_cancel_points[0][0]: if len(training_cancel_points) > 0 and ep+1 == training_cancel_points[0][0]:
print(f"Checking training cancel point: epoch={ep}, point={training_cancel_points[0]}, accuracy={epoch_tracker.accuracies[-1]}")
if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]: if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]:
print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep} epochs.") print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep+1} epochs.")
break; break;
training_cancel_points.pop(0) training_cancel_points.pop(0)