Compare commits

..

No commits in common. "c8201b5175b1070313181d03e203523d024194a5" and "5895f39874f970c7b2c52b7fe12a7361a7d06274" have entirely different histories.

6 changed files with 26 additions and 84 deletions

View File

@ -21,8 +21,7 @@ dependencies = [
"matplotlib>=3.6", "matplotlib>=3.6",
"numpy", "numpy",
"torch", "torch",
"scikit-learn", "scikit-learn"
"pandas",
] ]
[project.urls] [project.urls]

View File

@ -1,11 +1,6 @@
# Machine learning for material recognition with a triboelectric nanogenerator (TENG) # Machine learning for material recognition with a TENG
This project was written for my bachelor's thesis. (Bi)LSTM for name classification.
More information on the project are [on my website](https://quintern.xyz/en/teng.html).
It was written to classify TENG voltage output from pressing it against different materials.
Contents:
- Data preparation/plotting/loading utilites
- (Bi)LSTM + fully connected + softmax model for name classifiying TENG output
- Progress tracking utilities to easily find the best parameters
## Model training ## Model training
Adjust the parameters in `main.py` and run it. Adjust the parameters in `main.py` and run it.
@ -15,3 +10,4 @@ of the `<model_dir>` that was set in `main.py`.
## Model evaluation ## Model evaluation
Run `find_best_model.py <model_dir>` with the `<model_dir>` specified in `main.py` during training. Run `find_best_model.py <model_dir>` with the `<model_dir>` specified in `main.py` during training.

View File

@ -42,10 +42,9 @@ def test_interpol():
if __name__ == "__main__": if __name__ == "__main__":
# labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "antistatic_foil", "cardboard", "glass", "kapton", "bubble_wrap_PE", "fabric_PP" ]) labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "antistatic_foil", "cardboard", "glass", "kapton", "bubble_wrap_PE", "fabric_PP", ])
labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "antistatic_foil", "cardboard", "kapton", "bubble_wrap_PE", "fabric_PP" ]) # labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "kapton", "bubble_wrap_PE", "fabric_PP", ])
# labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "kapton", "bubble_wrap_PE", "fabric_PP" ]) models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_12" # where to save models, settings and results
models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_15" # where to save models, settings and results
if not path.isdir(models_dir): if not path.isdir(models_dir):
makedirs(models_dir) makedirs(models_dir)
data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data" data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data"
@ -54,18 +53,18 @@ if __name__ == "__main__":
# gen_6 best options: no glass, cardboard and antistatic_foil, not bidirectional, lr=0.0007, no datasplitter, 2 layers n_hidden = 10 # gen_6 best options: no glass, cardboard and antistatic_foil, not bidirectional, lr=0.0007, no datasplitter, 2 layers n_hidden = 10
# Test with # Test with
num_layers = [ 4, 5 ] num_layers = [ 2, 3 ]
hidden_size = [ 28, 36 ] hidden_size = [ 21, 28 ]
bidirectional = [ True ] bidirectional = [ False, True ]
t_const_int = ConstantInterval(0.01) # TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore? t_const_int = ConstantInterval(0.01) # TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore?
t_norm = Normalize(-1, 1) t_norm = Normalize(-1, 1)
transforms = [[]] #, [ t_norm, t_const_int ]] transforms = [[ t_norm ]] #, [ t_norm, t_const_int ]]
batch_sizes = [ 4 ] batch_sizes = [ 4 ]
splitters = [ DataSplitter(50, drop_if_smaller_than=30) ] # smallest file has length 68 TODO: try with 0.5-1second snippets splitters = [ DataSplitter(50, drop_if_smaller_than=30) ] # smallest file has length 68 TODO: try with 0.5-1second snippets
num_epochs = [ 80 ] num_epochs = [ 80 ]
# (epoch, min_accuracy) # (epoch, min_accuracy)
# training_cancel_points = [(15, 20), (40, 25)] training_cancel_points = [(15, 20), (40, 25)]
training_cancel_points = [] # training_cancel_points = []
args = [num_layers, hidden_size, bidirectional, [None], [None], [None], transforms, splitters, num_epochs, batch_sizes] args = [num_layers, hidden_size, bidirectional, [None], [None], [None], transforms, splitters, num_epochs, batch_sizes]
@ -82,7 +81,7 @@ if __name__ == "__main__":
None, None,
# lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9), # lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9),
# lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5), # lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5),
# lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 8, gamma=0.60, verbose=False), lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 8, gamma=0.60, verbose=False),
# lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 10, gamma=0.75, verbose=False), # lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 10, gamma=0.75, verbose=False),
] ]

View File

@ -111,28 +111,26 @@ class EpochTracker:
""" """
@param model_dir: Optional. If given, save to model_dir as svg @param model_dir: Optional. If given, save to model_dir as svg
""" """
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, layout="tight", figsize=(6, 6)) fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True, layout="tight")
ax[0].plot(self.epochs, self.accuracies, color="red") ax[0].plot(self.epochs, self.accuracies, color="red")
ax[0].set_ylabel("Accuracy") ax[0].set_ylabel("Accuracy")
ax[0].grid("minor")
ax[1].plot(self.epochs, self.learning_rate, color="green") ax[1].plot(self.epochs, self.learning_rate, color="green")
ax[1].set_ylabel("Learning Rate") ax[1].set_ylabel("Learning Rate")
ax[1].grid("minor")
# ax[2].plot(self.epochs, self.loss, color="blue") ax[2].plot(self.epochs, self.loss, color="blue")
# ax[2].set_ylabel("Loss") ax[2].set_ylabel("Loss")
fig.suptitle(title) fig.suptitle(title)
ax[-1].set_xlabel("Epoch") ax[2].set_xlabel("Epoch")
plt.tight_layout() plt.tight_layout()
if model_dir is not None: if model_dir is not None:
fig.savefig(f"{model_dir}/{name}.svg") fig.savefig(f"{model_dir}/{name}.svg")
return fig, ax return fig, ax
def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions", empty_zero=True): def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions"):
""" """
@param model_dir: Optional. If given, save to model_dir as svg @param model_dir: Optional. If given, save to model_dir as svg
@param ep: Epoch, defaults to last @param ep: Epoch, defaults to last
@ -143,23 +141,8 @@ class EpochTracker:
N = len(self.labels) N = len(self.labels)
label_names = self.labels.get_labels() label_names = self.labels.get_labels()
# print(label_names)
replace = {
"cloth": "fabric",
"foam": "foam_PDMS_pure",
"foil": "bubble_wrap",
"rigid_foam": "foam_PE",
"fabric_PP": "fabric",
"foam_PDMS_white": "foam_PDMS_pure",
"foam_PDMS_black": "foam_PEDOT",
"bubble_wrap_PE": "bubble_wrap",
}
label_names = [ replace[label] if label in replace else label for label in label_names ]
if len(label_names) > 6: fig, ax = plt.subplots(layout="tight")
fig, ax = plt.subplots(layout="tight", figsize=(7, 6))
else:
fig, ax = plt.subplots(layout="tight", figsize=(6, 5))
im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.) im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.)
ax.set_xticks(np.arange(N)) ax.set_xticks(np.arange(N))
ax.set_yticks(np.arange(N)) ax.set_yticks(np.arange(N))
@ -172,21 +155,14 @@ class EpochTracker:
for i in range(1, N): for i in range(1, N):
ax.axhline(i-0.5, color='black', linewidth=1) ax.axhline(i-0.5, color='black', linewidth=1)
# for i in range(1, N):
# ax.axvline(i-0.5, color='#bbb', linewidth=1)
# rotate the x-axis labels for better readability # rotate the x-axis labels for better readability
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# create annotations # create annotations
for i in range(N): for i in range(N):
for j in range(N): for j in range(N):
val = round(normalized_predictions[i, j], 2) text = ax.text(j, i, round(normalized_predictions[i, j], 2),
if empty_zero and val == 0: continue ha="center", va="center", color="black")
color = "black"
if normalized_predictions[i, j] >= 0.6: color = "white"
text = ax.text(j, i, val,
ha="center", va="center", color=color)
# add colorbar # add colorbar
cbar = ax.figure.colorbar(im, ax=ax) cbar = ax.figure.colorbar(im, ax=ax)

View File

@ -10,13 +10,8 @@ import threading
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from teng_ml.util.transform import Multiply
# groups: date, name, n_object, voltage, distance, index # groups: date, name, n_object, voltage, distance, index
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" # re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
# for teng_1
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_()(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
# for teng_2
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
class LabelConverter: class LabelConverter:
@ -56,7 +51,7 @@ class Datasample:
def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False): def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
self.date = date self.date = date
self.label = label self.label = label
self.n_object = 0 if n_object == "" else int(n_object) self.n_object = int(n_object)
self.voltage = float(voltage) self.voltage = float(voltage)
self.distance = float(distance) self.distance = float(distance)
self.index = int(index) self.index = int(index)
@ -91,19 +86,6 @@ class Dataset:
""" """
self.transforms = transforms self.transforms = transforms
self.data = [] # (data, label) self.data = [] # (data, label)
# NORMALIZE ALL DATA WITH THE SAME FACTOR
# sup = 0
# inf = 0
# for sample in datasamples:
# data = sample.get_data()
# max_ = np.max(data[:,2])
# min_ = np.min(data[:,2])
# if max_ > sup: sup = max_
# if min_ < inf: inf = min_
# multiplier = 1 / max(sup, abs(inf))
# self.transforms.append(Multiply(multiplier))
for sample in datasamples: for sample in datasamples:
data = self.apply_transforms(sample.get_data()) data = self.apply_transforms(sample.get_data())
if split_function is None: if split_function is None:
@ -146,7 +128,7 @@ def get_datafiles(datadir, labels: LabelConverter, exclude_n_object=None, filter
label = match.groups()[1] label = match.groups()[1]
if label not in labels: continue if label not in labels: continue
sample_n_object = 0 if match.groups()[2] == "" else int(match.groups()[2]) sample_n_object = float(match.groups()[2])
if exclude_n_object and exclude_n_object == sample_n_object: continue if exclude_n_object and exclude_n_object == sample_n_object: continue
sample_voltage = float(match.groups()[3]) sample_voltage = float(match.groups()[3])
if filter_voltage and filter_voltage != sample_voltage: continue if filter_voltage and filter_voltage != sample_voltage: continue

View File

@ -1,6 +1,5 @@
import numpy as np import numpy as np
from scipy.interpolate import interp1d from scipy.interpolate import interp1d
from torch import mul
class Normalize: class Normalize:
""" """
@ -42,15 +41,6 @@ class NormalizeAmplitude:
return f"NormalizeAmplitude(high={self.high})" return f"NormalizeAmplitude(high={self.high})"
class Multiply:
def __init__(self, multiplier):
self.multiplier = multiplier
def __call__(self, data):
return data * self.multiplier
def __repr__(self):
return f"Multiply(multiplier={self.multiplier})"
class ConstantInterval: class ConstantInterval:
""" """
Interpolate the data to have a constant interval / sample rate, Interpolate the data to have a constant interval / sample rate,