diff --git a/teng_ml/tracker/epoch_tracker.py b/teng_ml/tracker/epoch_tracker.py index 0f9cd3c..7c87cec 100644 --- a/teng_ml/tracker/epoch_tracker.py +++ b/teng_ml/tracker/epoch_tracker.py @@ -1,5 +1,6 @@ from ..util.data_loader import LabelConverter import matplotlib.pyplot as plt +import matplotlib.colors as colors import time import torch import numpy as np @@ -141,8 +142,7 @@ class EpochTracker: label_names = self.labels.get_labels() fig, ax = plt.subplots(layout="tight") - - im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu' + im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.) ax.set_xticks(np.arange(N)) ax.set_yticks(np.arange(N)) ax.set_xticklabels(label_names)