removed debug

This commit is contained in:
matthias@arch 2023-08-14 18:43:27 +02:00
parent 61321e3919
commit 37bb1f444e

View File

@ -104,18 +104,7 @@ class EpochTracker:
statistics = [ [ 0 for _ in range(len(self.labels)) ] for _ in range(len(self.labels)) ] statistics = [ [ 0 for _ in range(len(self.labels)) ] for _ in range(len(self.labels)) ]
for corr, pred in self.predictions[epoch]: for corr, pred in self.predictions[epoch]:
for batch in range(len(corr)): for batch in range(len(corr)):
try:
statistics[corr[batch]][pred[batch]] += 1 statistics[corr[batch]][pred[batch]] += 1
except IndexError as e:
print(f"IndexError in get_predictions_per_label: epoch={epoch}, len(corr)={len(corr)}, len(pred)={len(pred)}, batch={batch}, len(labels)={len(self.labels)}, len(statistics)={len(statistics)}")
print(f"statistics: {statistics}")
print(f"corr: {corr}")
print(f"pred: {pred}")
if batch in range(len(corr)):
if corr[batch] in range(len(statistics)):
print(f"len(statistics[corr[batch]])={len(statistics[corr[batch]])}")
print(f"corr[batch]={corr[batch]}, pred[batch]={pred[batch]}")
raise e
return statistics return statistics
def plot_training(self, title="Training Summary", model_dir=None, name="img_training"): def plot_training(self, title="Training Summary", model_dir=None, name="img_training"):