fixed epoch for cancel pts
This commit is contained in:
parent
1d05da3abf
commit
61321e3919
@ -68,10 +68,9 @@ def train(model, optimizer, scheduler, loss_func, train_loader: DataLoader, st:
|
|||||||
else: end='\n'
|
else: end='\n'
|
||||||
print(f"Training:", epoch_tracker.get_epoch_summary_str(), end=end)
|
print(f"Training:", epoch_tracker.get_epoch_summary_str(), end=end)
|
||||||
# cancel training if model is not good enough
|
# cancel training if model is not good enough
|
||||||
if len(training_cancel_points) > 0 and ep == training_cancel_points[0][0]:
|
if len(training_cancel_points) > 0 and ep+1 == training_cancel_points[0][0]:
|
||||||
print(f"Checking training cancel point: epoch={ep}, point={training_cancel_points[0]}, accuracy={epoch_tracker.accuracies[-1]}")
|
|
||||||
if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]:
|
if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]:
|
||||||
print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep} epochs.")
|
print(f"Training cancelled because the models accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep+1} epochs.")
|
||||||
break;
|
break;
|
||||||
training_cancel_points.pop(0)
|
training_cancel_points.pop(0)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user