fixed epoch for cancel pts
This commit is contained in:
parent
1d05da3abf
commit
61321e3919
@ -68,10 +68,9 @@ def train(model, optimizer, scheduler, loss_func, train_loader: DataLoader, st:
|
||||
else: end='\n'
|
||||
print(f"Training:", epoch_tracker.get_epoch_summary_str(), end=end)
|
||||
# cancel training if model is not good enough
|
||||
if len(training_cancel_points) > 0 and ep == training_cancel_points[0][0]:
|
||||
print(f"Checking training cancel point: epoch={ep}, point={training_cancel_points[0]}, accuracy={epoch_tracker.accuracies[-1]}")
|
||||
if len(training_cancel_points) > 0 and ep+1 == training_cancel_points[0][0]:
|
||||
if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]:
|
||||
print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep} epochs.")
|
||||
print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep+1} epochs.")
|
||||
break;
|
||||
training_cancel_points.pop(0)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user