from os import stat from ..util.data_loader import LabelConverter import matplotlib.pyplot as plt import matplotlib.colors as colors import time import torch import numpy as np class EpochTracker: """ Track accuracy, loss, learning_rate etc. during model training Can also be used for validation (which will probably be only one epoch) """ def __init__(self, labels: LabelConverter): self.labels = labels self.times: list[float] = [] # (epoch) self.predictions = [[]] # (epoch, batch_nr, (correct_indices | predicted_indices), ind:ex_nr) self.loss: list[float] = [] # (epoch) self.learning_rate: list[float] = [] # (epoch) self.epochs: list[int] = [] # 1 based for FINISHED epochs self._current_epoch = 0 # 0 based # after training self.accuracies: list[float] = [] # (epoch) def begin(self): self.times.append(time.time()) def end(self): self.times.append(time.time()) # if end_epoch was called before end: if len(self.predictions[-1]) == 0: self.predictions.pop() self._current_epoch -= 1 else: # if end_epoch was not called self.epochs.append(len(self.epochs) + 1) self._calculate_accuracies(self._current_epoch) s = f"Summary: After {self.epochs[-1]} epochs: " s += f"Accuracy={self.accuracies[-1]:.2f}%" s += f", Total time={self.get_total_time():.2f}s" return s def get_total_time(self): if len(self.times) > 1: return self.times[-1] - self.times[0] else: return -1 # # EPOCH # def end_epoch(self, loss, learning_rate): """ loss and learning_rate of last epoch call before scheduler.step() """ self.times.append(time.time()) self.epochs.append(len(self.epochs) + 1) if type(loss) == torch.Tensor: self.loss.append(loss.item()) else: self.loss.append(loss) self.learning_rate.append(learning_rate) self._calculate_accuracies(self._current_epoch) self._current_epoch += 1 self.predictions.append([]) def get_epoch_summary_str(self, ep=-1): """call after next_epoch()""" m = max(ep, 0) # if ep == -1, check if len is > 0 assert(len(self.epochs) > m) s = f"Epoch {self.epochs[ep]:3}" if len(self.accuracies) > m:s += f", Accuracy={self.accuracies[ep]:.2f}%" if len(self.loss) > m: s += f", Loss={self.loss[ep]:.3f}" if len(self.loss) > m: s += f", lr={self.learning_rate[ep]:.4f}" if len(self.times) > m+1: s += f", dt={self.times[ep] - self.times[ep-1]:.2f}s" return s def add_prediction(self, correct_indices: torch.Tensor, predicted_indices: torch.Tensor): """for accuracy calculation""" self.predictions[self._current_epoch].append((correct_indices.detach().numpy(), predicted_indices.detach().numpy())) # # STATISTICS # def get_count_per_label(self, epoch=-1): """ the number of times where