improve plots
This commit is contained in:
parent
3d823a19c7
commit
c8201b5175
@ -111,26 +111,28 @@ class EpochTracker:
|
|||||||
"""
|
"""
|
||||||
@param model_dir: Optional. If given, save to model_dir as svg
|
@param model_dir: Optional. If given, save to model_dir as svg
|
||||||
"""
|
"""
|
||||||
fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True, layout="tight")
|
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, layout="tight", figsize=(6, 6))
|
||||||
|
|
||||||
ax[0].plot(self.epochs, self.accuracies, color="red")
|
ax[0].plot(self.epochs, self.accuracies, color="red")
|
||||||
ax[0].set_ylabel("Accuracy")
|
ax[0].set_ylabel("Accuracy")
|
||||||
|
ax[0].grid("minor")
|
||||||
|
|
||||||
ax[1].plot(self.epochs, self.learning_rate, color="green")
|
ax[1].plot(self.epochs, self.learning_rate, color="green")
|
||||||
ax[1].set_ylabel("Learning Rate")
|
ax[1].set_ylabel("Learning Rate")
|
||||||
|
ax[1].grid("minor")
|
||||||
|
|
||||||
ax[2].plot(self.epochs, self.loss, color="blue")
|
# ax[2].plot(self.epochs, self.loss, color="blue")
|
||||||
ax[2].set_ylabel("Loss")
|
# ax[2].set_ylabel("Loss")
|
||||||
|
|
||||||
fig.suptitle(title)
|
fig.suptitle(title)
|
||||||
ax[2].set_xlabel("Epoch")
|
ax[-1].set_xlabel("Epoch")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
fig.savefig(f"{model_dir}/{name}.svg")
|
fig.savefig(f"{model_dir}/{name}.svg")
|
||||||
|
|
||||||
return fig, ax
|
return fig, ax
|
||||||
|
|
||||||
def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions"):
|
def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions", empty_zero=True):
|
||||||
"""
|
"""
|
||||||
@param model_dir: Optional. If given, save to model_dir as svg
|
@param model_dir: Optional. If given, save to model_dir as svg
|
||||||
@param ep: Epoch, defaults to last
|
@param ep: Epoch, defaults to last
|
||||||
@ -141,8 +143,23 @@ class EpochTracker:
|
|||||||
|
|
||||||
N = len(self.labels)
|
N = len(self.labels)
|
||||||
label_names = self.labels.get_labels()
|
label_names = self.labels.get_labels()
|
||||||
|
# print(label_names)
|
||||||
|
replace = {
|
||||||
|
"cloth": "fabric",
|
||||||
|
"foam": "foam_PDMS_pure",
|
||||||
|
"foil": "bubble_wrap",
|
||||||
|
"rigid_foam": "foam_PE",
|
||||||
|
"fabric_PP": "fabric",
|
||||||
|
"foam_PDMS_white": "foam_PDMS_pure",
|
||||||
|
"foam_PDMS_black": "foam_PEDOT",
|
||||||
|
"bubble_wrap_PE": "bubble_wrap",
|
||||||
|
}
|
||||||
|
label_names = [ replace[label] if label in replace else label for label in label_names ]
|
||||||
|
|
||||||
fig, ax = plt.subplots(layout="tight")
|
if len(label_names) > 6:
|
||||||
|
fig, ax = plt.subplots(layout="tight", figsize=(7, 6))
|
||||||
|
else:
|
||||||
|
fig, ax = plt.subplots(layout="tight", figsize=(6, 5))
|
||||||
im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.)
|
im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.)
|
||||||
ax.set_xticks(np.arange(N))
|
ax.set_xticks(np.arange(N))
|
||||||
ax.set_yticks(np.arange(N))
|
ax.set_yticks(np.arange(N))
|
||||||
@ -155,14 +172,21 @@ class EpochTracker:
|
|||||||
for i in range(1, N):
|
for i in range(1, N):
|
||||||
ax.axhline(i-0.5, color='black', linewidth=1)
|
ax.axhline(i-0.5, color='black', linewidth=1)
|
||||||
|
|
||||||
|
# for i in range(1, N):
|
||||||
|
# ax.axvline(i-0.5, color='#bbb', linewidth=1)
|
||||||
|
|
||||||
# rotate the x-axis labels for better readability
|
# rotate the x-axis labels for better readability
|
||||||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
||||||
|
|
||||||
# create annotations
|
# create annotations
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
for j in range(N):
|
for j in range(N):
|
||||||
text = ax.text(j, i, round(normalized_predictions[i, j], 2),
|
val = round(normalized_predictions[i, j], 2)
|
||||||
ha="center", va="center", color="black")
|
if empty_zero and val == 0: continue
|
||||||
|
color = "black"
|
||||||
|
if normalized_predictions[i, j] >= 0.6: color = "white"
|
||||||
|
text = ax.text(j, i, val,
|
||||||
|
ha="center", va="center", color=color)
|
||||||
|
|
||||||
# add colorbar
|
# add colorbar
|
||||||
cbar = ax.figure.colorbar(im, ax=ax)
|
cbar = ax.figure.colorbar(im, ax=ax)
|
||||||
|
Loading…
Reference in New Issue
Block a user