diff --git a/teng_ml/tracker/epoch_tracker.py b/teng_ml/tracker/epoch_tracker.py index 6a9a9ac..cdd4461 100644 --- a/teng_ml/tracker/epoch_tracker.py +++ b/teng_ml/tracker/epoch_tracker.py @@ -104,18 +104,7 @@ class EpochTracker: statistics = [ [ 0 for _ in range(len(self.labels)) ] for _ in range(len(self.labels)) ] for corr, pred in self.predictions[epoch]: for batch in range(len(corr)): - try: - statistics[corr[batch]][pred[batch]] += 1 - except IndexError as e: - print(f"IndexError in get_predictions_per_label: epoch={epoch}, len(corr)={len(corr)}, len(pred)={len(pred)}, batch={batch}, len(labels)={len(self.labels)}, len(statistics)={len(statistics)}") - print(f"statistics: {statistics}") - print(f"corr: {corr}") - print(f"pred: {pred}") - if batch in range(len(corr)): - if corr[batch] in range(len(statistics)): - print(f"len(statistics[corr[batch]])={len(statistics[corr[batch]])}") - print(f"corr[batch]={corr[batch]}, pred[batch]={pred[batch]}") - raise e + statistics[corr[batch]][pred[batch]] += 1 return statistics def plot_training(self, title="Training Summary", model_dir=None, name="img_training"):