teng-ml/teng_ml/tracker/epoch_tracker.py

213 lines
7.8 KiB
Python
Raw Permalink Normal View History

2023-08-10 17:29:09 +02:00
from os import stat
2023-05-10 22:44:14 +02:00
from ..util.data_loader import LabelConverter
import matplotlib.pyplot as plt
2023-08-02 10:57:01 +02:00
import matplotlib.colors as colors
2023-05-10 22:44:14 +02:00
import time
import torch
import numpy as np
class EpochTracker:
    """
    Track accuracy, loss, learning_rate etc. during model training
    Can also be used for validation (which will probably be only one epoch)

    Call order used by this class:
        begin() once, add_prediction() for every batch,
        end_epoch() after every epoch, end() once when everything is done.
    """

    def __init__(self, labels: "LabelConverter"):
        """
        @param labels: label converter, only len(labels) and labels.get_labels() are used here
        """
        self.labels = labels

        self.times: list[float] = []  # begin() timestamp, followed by one timestamp per end_epoch()/end() call
        self.predictions = [[]]  # (epoch, batch_nr, (correct_indices | predicted_indices), ind:ex_nr)
        self.loss: list[float] = []  # (epoch)
        self.learning_rate: list[float] = []  # (epoch)
        self.epochs: list[int] = []  # 1 based for FINISHED epochs
        self._current_epoch = 0  # 0 based

        # after training
        self.accuracies: list[float] = []  # (epoch)

    def begin(self):
        """Store the start timestamp. Call once before the first epoch."""
        self.times.append(time.time())

    def end(self):
        """
        Store the end timestamp and finish the epoch that is still open (if any).
        @returns summary string with the number of epochs, the last accuracy and the total time
        """
        self.times.append(time.time())
        # if end_epoch was called before end:
        if len(self.predictions[-1]) == 0:
            # drop the empty prediction list that end_epoch opened for a next epoch
            self.predictions.pop()
            self._current_epoch -= 1
        else:  # if end_epoch was not called
            self.epochs.append(len(self.epochs) + 1)
            self._calculate_accuracies(self._current_epoch)
        s = f"Summary: After {self.epochs[-1]} epochs: "
        s += f"Accuracy={self.accuracies[-1]:.2f}%"
        s += f", Total time={self.get_total_time():.2f}s"
        return s

    def get_total_time(self):
        """
        @returns seconds between the first and the last stored timestamp, -1 if there are less than two timestamps
        """
        if len(self.times) > 1:
            return self.times[-1] - self.times[0]
        return -1

    #
    # EPOCH
    #
    def end_epoch(self, loss, learning_rate):
        """
        loss and learning_rate of last epoch
        call before scheduler.step()
        @param loss: float or torch.Tensor (tensors are converted with .item())
        @param learning_rate: learning rate used during the epoch
        """
        self.times.append(time.time())
        self.epochs.append(len(self.epochs) + 1)
        # isinstance instead of an exact type comparison, so that Tensor subclasses are converted as well
        if isinstance(loss, torch.Tensor):
            self.loss.append(loss.item())
        else:
            self.loss.append(loss)
        self.learning_rate.append(learning_rate)
        self._calculate_accuracies(self._current_epoch)
        # open the next epoch
        self._current_epoch += 1
        self.predictions.append([])

    def get_epoch_summary_str(self, ep=-1):
        """
        call after end_epoch()
        @param ep: 0 based epoch index, negative values count from the last finished epoch
        @returns summary string, values that were not recorded for the epoch are left out
        """
        n_epochs = len(self.epochs)
        # resolve negative indices once, so that all per-epoch lists are accessed with the same index
        idx = ep + n_epochs if ep < 0 else ep
        assert 0 <= idx < n_epochs, f"epoch {ep} is not finished yet"
        s = f"Epoch {self.epochs[idx]:3}"
        if len(self.accuracies) > idx:
            s += f", Accuracy={self.accuracies[idx]:.2f}%"
        if len(self.loss) > idx:
            s += f", Loss={self.loss[idx]:.3f}"
        if len(self.learning_rate) > idx:
            s += f", lr={self.learning_rate[idx]:.4f}"
        # times[0] is the begin() timestamp and times[i+1] the end of epoch i
        # NOTE: assumes begin() was called before the first epoch
        if len(self.times) > idx + 1:
            s += f", dt={self.times[idx + 1] - self.times[idx]:.2f}s"
        return s

    def add_prediction(self, correct_indices: torch.Tensor, predicted_indices: torch.Tensor):
        """
        for accuracy calculation
        @param correct_indices: indices of the correct labels of one batch
        @param predicted_indices: indices of the predicted labels of the same batch
        """
        # .cpu() is required for tensors on other devices and does nothing for cpu tensors
        self.predictions[self._current_epoch].append((
            correct_indices.detach().cpu().numpy(),
            predicted_indices.detach().cpu().numpy(),
        ))

    #
    # STATISTICS
    #
    def get_count_per_label(self, epoch=-1):
        """
        the number of times where <label> was the correct label, per label
        @returns shape: (label)
        """
        count_per_label = [0] * len(self.labels)
        for corr, _ in self.predictions[epoch]:
            for label in corr:
                count_per_label[label] += 1
        return count_per_label

    def get_predictions_per_label(self, epoch=-1):
        """
        How often label_i was predicted, when label_j was the correct label
        @returns shape: (label_j, label_i)
        """
        n_labels = len(self.labels)
        statistics = [[0] * n_labels for _ in range(n_labels)]
        for corr, pred in self.predictions[epoch]:
            for label_j, label_i in zip(corr, pred):
                statistics[label_j][label_i] += 1
        return statistics

    def plot_training(self, title="Training Summary", model_dir=None, name="img_training"):
        """
        Plot accuracy and learning rate over the epochs (the loss is currently not plotted)
        @param model_dir: Optional. If given, save to model_dir as svg
        @returns fig, ax
        """
        fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, layout="tight", figsize=(6, 6))

        ax[0].plot(self.epochs, self.accuracies, color="red")
        ax[0].set_ylabel("Accuracy")
        # NOTE(review): "minor" is passed as first positional argument of grid() - confirm that which="minor" was not intended
        ax[0].grid("minor")

        ax[1].plot(self.epochs, self.learning_rate, color="green")
        ax[1].set_ylabel("Learning Rate")
        ax[1].grid("minor")

        fig.suptitle(title)
        ax[-1].set_xlabel("Epoch")
        plt.tight_layout()
        if model_dir is not None:
            fig.savefig(f"{model_dir}/{name}.svg")
        return fig, ax

    def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions", empty_zero=True):
        """
        Plot the row-normalized matrix of predicted label vs. correct label
        @param model_dir: Optional. If given, save to model_dir as svg
        @param ep: Epoch, defaults to last
        @param empty_zero: If True, cells that round to 0 are not annotated
        @returns fig, ax
        """
        # Normalize the data
        predictions_per_label = np.asarray(self.get_predictions_per_label(ep), dtype=float)
        row_sums = predictions_per_label.sum(axis=1, keepdims=True)
        # labels that never were the correct label have an all-zero row: keep it at 0 instead of dividing by 0
        normalized_predictions = np.divide(
            predictions_per_label,
            row_sums,
            out=np.zeros_like(predictions_per_label),
            where=row_sums > 0,
        )
        N = len(self.labels)
        label_names = self.labels.get_labels()
        # names that are replaced in the plot
        replace = {
            "cloth": "fabric",
            "foam": "foam_PDMS_pure",
            "foil": "bubble_wrap",
            "rigid_foam": "foam_PE",
            "fabric_PP": "fabric",
            "foam_PDMS_white": "foam_PDMS_pure",
            "foam_PDMS_black": "foam_PEDOT",
            "bubble_wrap_PE": "bubble_wrap",
        }
        label_names = [replace.get(label, label) for label in label_names]

        # more space when there are many labels
        figsize = (7, 6) if len(label_names) > 6 else (6, 5)
        fig, ax = plt.subplots(layout="tight", figsize=figsize)
        im = ax.imshow(normalized_predictions, cmap='Blues')
        ax.set_xticks(np.arange(N))
        ax.set_yticks(np.arange(N))
        ax.set_xticklabels(label_names)
        ax.set_yticklabels(label_names)
        ax.set_xlabel('Predicted Label')
        ax.set_ylabel('Correct Label')

        # horizontal lines between labels to better show that the sum of a row is 1
        for i in range(1, N):
            ax.axhline(i-0.5, color='black', linewidth=1)

        # rotate the x-axis labels for better readability
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

        # create annotations
        for i in range(N):
            for j in range(N):
                val = round(normalized_predictions[i, j], 2)
                if empty_zero and val == 0:
                    continue
                # white text on the dark cells
                color = "white" if normalized_predictions[i, j] >= 0.6 else "black"
                ax.text(j, i, val, ha="center", va="center", color=color)

        # add colorbar
        ax.figure.colorbar(im, ax=ax)
        ax.set_title(title)
        plt.tight_layout()
        if model_dir is not None:
            fig.savefig(f"{model_dir}/{name}.svg")
        return fig, ax

    #
    # CALCULATION
    #
    def _calculate_accuracies(self, ep):
        """
        Calculate the accuracy in percent for epoch ep and store it in self.accuracies[ep]
        Raises ZeroDivisionError if no predictions were added for the epoch
        @param ep: 0 based epoch index
        """
        correct_predictions = 0
        total_predictions = 0
        for correct_indices, predicted_indices in self.predictions[ep]:
            correct_predictions += (predicted_indices == correct_indices).sum().item()
            total_predictions += len(predicted_indices)
        accuracy = correct_predictions / total_predictions * 100
        # fill up with -1 as placeholder until index ep exists
        while len(self.accuracies) <= ep:
            self.accuracies.append(-1)
        self.accuracies[ep] = accuracy