diff --git a/teng_ml/tracker/epoch_tracker.py b/teng_ml/tracker/epoch_tracker.py index cdd4461..e6b999d 100644 --- a/teng_ml/tracker/epoch_tracker.py +++ b/teng_ml/tracker/epoch_tracker.py @@ -111,26 +111,28 @@ class EpochTracker: """ @param model_dir: Optional. If given, save to model_dir as svg """ - fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True, layout="tight") + fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, layout="tight", figsize=(6, 6)) ax[0].plot(self.epochs, self.accuracies, color="red") ax[0].set_ylabel("Accuracy") + ax[0].grid("minor") ax[1].plot(self.epochs, self.learning_rate, color="green") ax[1].set_ylabel("Learning Rate") + ax[1].grid("minor") - ax[2].plot(self.epochs, self.loss, color="blue") - ax[2].set_ylabel("Loss") + # ax[2].plot(self.epochs, self.loss, color="blue") + # ax[2].set_ylabel("Loss") fig.suptitle(title) - ax[2].set_xlabel("Epoch") + ax[-1].set_xlabel("Epoch") plt.tight_layout() if model_dir is not None: fig.savefig(f"{model_dir}/{name}.svg") return fig, ax - def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions"): + def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions", empty_zero=True): """ @param model_dir: Optional. If given, save to model_dir as svg @param ep: Epoch, defaults to last @@ -141,8 +143,23 @@ class EpochTracker: N = len(self.labels) label_names = self.labels.get_labels() + # print(label_names) + replace = { + "cloth": "fabric", + "foam": "foam_PDMS_pure", + "foil": "bubble_wrap", + "rigid_foam": "foam_PE", + "fabric_PP": "fabric", + "foam_PDMS_white": "foam_PDMS_pure", + "foam_PDMS_black": "foam_PEDOT", + "bubble_wrap_PE": "bubble_wrap", + } + label_names = [ replace[label] if label in replace else label for label in label_names ] - fig, ax = plt.subplots(layout="tight") + if len(label_names) > 6: + fig, ax = plt.subplots(layout="tight", figsize=(7, 6)) + else: + fig, ax = plt.subplots(layout="tight", figsize=(6, 5)) im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.) ax.set_xticks(np.arange(N)) ax.set_yticks(np.arange(N)) @@ -155,14 +172,21 @@ class EpochTracker: for i in range(1, N): ax.axhline(i-0.5, color='black', linewidth=1) + # for i in range(1, N): + # ax.axvline(i-0.5, color='#bbb', linewidth=1) + # rotate the x-axis labels for better readability plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # create annotations for i in range(N): for j in range(N): - text = ax.text(j, i, round(normalized_predictions[i, j], 2), - ha="center", va="center", color="black") + val = round(normalized_predictions[i, j], 2) + if empty_zero and val == 0: continue + color = "black" + if normalized_predictions[i, j] >= 0.6: color = "white" + text = ax.text(j, i, val, + ha="center", va="center", color=color) # add colorbar cbar = ax.figure.colorbar(im, ax=ax)