Compare commits
6 Commits
5895f39874
...
c8201b5175
Author | SHA1 | Date | |
---|---|---|---|
|
c8201b5175 | ||
|
3d823a19c7 | ||
|
570d6dbc25 | ||
|
77b266929d | ||
|
f281b981ea | ||
|
9f9d3d7e06 |
@ -21,7 +21,8 @@ dependencies = [
|
|||||||
"matplotlib>=3.6",
|
"matplotlib>=3.6",
|
||||||
"numpy",
|
"numpy",
|
||||||
"torch",
|
"torch",
|
||||||
"scikit-learn"
|
"scikit-learn",
|
||||||
|
"pandas",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
14
readme.md
14
readme.md
@ -1,13 +1,17 @@
|
|||||||
# Machine learning for material recognition with a TENG
|
# Machine learning for material recognition with a triboelectric nanogenerator (TENG)
|
||||||
(Bi)LSTM for name classification.
|
This project was written for my bachelor's thesis.
|
||||||
More information on the project are [on my website](https://quintern.xyz/en/teng.html).
|
|
||||||
|
It was written to classify TENG voltage output from pressing it against different materials.
|
||||||
|
Contents:
|
||||||
|
- Data preparation/plotting/loading utilites
|
||||||
|
- (Bi)LSTM + fully connected + softmax model for name classifiying TENG output
|
||||||
|
- Progress tracking utilities to easily find the best parameters
|
||||||
|
|
||||||
## Model training
|
## Model training
|
||||||
Adjust the parameters in `main.py` and run it.
|
Adjust the parameters in `main.py` and run it.
|
||||||
All models and the settings they were trained with are automatically serialized with pickle and stored in a subfolder
|
All models and the settings they were trained with are automatically serialized with pickle and stored in a subfolder
|
||||||
of the `<model_dir>` that was set in `main.py`.
|
of the `<model_dir>` that was set in `main.py`.
|
||||||
|
|
||||||
|
|
||||||
## Model evaluation
|
## Model evaluation
|
||||||
Run `find_best_model.py <model_dir>` with the `<model_dir>` specified in `main.py` during training.
|
Run `find_best_model.py <model_dir>` with the `<model_dir>` specified in `main.py` during training.
|
||||||
|
|
||||||
|
@ -42,9 +42,10 @@ def test_interpol():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "antistatic_foil", "cardboard", "glass", "kapton", "bubble_wrap_PE", "fabric_PP", ])
|
# labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "antistatic_foil", "cardboard", "glass", "kapton", "bubble_wrap_PE", "fabric_PP" ])
|
||||||
# labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "kapton", "bubble_wrap_PE", "fabric_PP", ])
|
labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "antistatic_foil", "cardboard", "kapton", "bubble_wrap_PE", "fabric_PP" ])
|
||||||
models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_12" # where to save models, settings and results
|
# labels = LabelConverter(["foam_PDMS_white", "foam_PDMS_black", "foam_PDMS_TX100", "foam_PE", "kapton", "bubble_wrap_PE", "fabric_PP" ])
|
||||||
|
models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_15" # where to save models, settings and results
|
||||||
if not path.isdir(models_dir):
|
if not path.isdir(models_dir):
|
||||||
makedirs(models_dir)
|
makedirs(models_dir)
|
||||||
data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data"
|
data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data"
|
||||||
@ -53,18 +54,18 @@ if __name__ == "__main__":
|
|||||||
# gen_6 best options: no glass, cardboard and antistatic_foil, not bidirectional, lr=0.0007, no datasplitter, 2 layers n_hidden = 10
|
# gen_6 best options: no glass, cardboard and antistatic_foil, not bidirectional, lr=0.0007, no datasplitter, 2 layers n_hidden = 10
|
||||||
|
|
||||||
# Test with
|
# Test with
|
||||||
num_layers = [ 2, 3 ]
|
num_layers = [ 4, 5 ]
|
||||||
hidden_size = [ 21, 28 ]
|
hidden_size = [ 28, 36 ]
|
||||||
bidirectional = [ False, True ]
|
bidirectional = [ True ]
|
||||||
t_const_int = ConstantInterval(0.01) # TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore?
|
t_const_int = ConstantInterval(0.01) # TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore?
|
||||||
t_norm = Normalize(-1, 1)
|
t_norm = Normalize(-1, 1)
|
||||||
transforms = [[ t_norm ]] #, [ t_norm, t_const_int ]]
|
transforms = [[]] #, [ t_norm, t_const_int ]]
|
||||||
batch_sizes = [ 4 ]
|
batch_sizes = [ 4 ]
|
||||||
splitters = [ DataSplitter(50, drop_if_smaller_than=30) ] # smallest file has length 68 TODO: try with 0.5-1second snippets
|
splitters = [ DataSplitter(50, drop_if_smaller_than=30) ] # smallest file has length 68 TODO: try with 0.5-1second snippets
|
||||||
num_epochs = [ 80 ]
|
num_epochs = [ 80 ]
|
||||||
# (epoch, min_accuracy)
|
# (epoch, min_accuracy)
|
||||||
training_cancel_points = [(15, 20), (40, 25)]
|
# training_cancel_points = [(15, 20), (40, 25)]
|
||||||
# training_cancel_points = []
|
training_cancel_points = []
|
||||||
|
|
||||||
args = [num_layers, hidden_size, bidirectional, [None], [None], [None], transforms, splitters, num_epochs, batch_sizes]
|
args = [num_layers, hidden_size, bidirectional, [None], [None], [None], transforms, splitters, num_epochs, batch_sizes]
|
||||||
|
|
||||||
@ -81,7 +82,7 @@ if __name__ == "__main__":
|
|||||||
None,
|
None,
|
||||||
# lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9),
|
# lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9),
|
||||||
# lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5),
|
# lambda optimizer, st: torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5),
|
||||||
lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 8, gamma=0.60, verbose=False),
|
# lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 8, gamma=0.60, verbose=False),
|
||||||
# lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 10, gamma=0.75, verbose=False),
|
# lambda optimizer, st: torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.num_epochs // 10, gamma=0.75, verbose=False),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -111,26 +111,28 @@ class EpochTracker:
|
|||||||
"""
|
"""
|
||||||
@param model_dir: Optional. If given, save to model_dir as svg
|
@param model_dir: Optional. If given, save to model_dir as svg
|
||||||
"""
|
"""
|
||||||
fig, ax = plt.subplots(nrows=3, ncols=1, sharex=True, layout="tight")
|
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True, layout="tight", figsize=(6, 6))
|
||||||
|
|
||||||
ax[0].plot(self.epochs, self.accuracies, color="red")
|
ax[0].plot(self.epochs, self.accuracies, color="red")
|
||||||
ax[0].set_ylabel("Accuracy")
|
ax[0].set_ylabel("Accuracy")
|
||||||
|
ax[0].grid("minor")
|
||||||
|
|
||||||
ax[1].plot(self.epochs, self.learning_rate, color="green")
|
ax[1].plot(self.epochs, self.learning_rate, color="green")
|
||||||
ax[1].set_ylabel("Learning Rate")
|
ax[1].set_ylabel("Learning Rate")
|
||||||
|
ax[1].grid("minor")
|
||||||
|
|
||||||
ax[2].plot(self.epochs, self.loss, color="blue")
|
# ax[2].plot(self.epochs, self.loss, color="blue")
|
||||||
ax[2].set_ylabel("Loss")
|
# ax[2].set_ylabel("Loss")
|
||||||
|
|
||||||
fig.suptitle(title)
|
fig.suptitle(title)
|
||||||
ax[2].set_xlabel("Epoch")
|
ax[-1].set_xlabel("Epoch")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
if model_dir is not None:
|
if model_dir is not None:
|
||||||
fig.savefig(f"{model_dir}/{name}.svg")
|
fig.savefig(f"{model_dir}/{name}.svg")
|
||||||
|
|
||||||
return fig, ax
|
return fig, ax
|
||||||
|
|
||||||
def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions"):
|
def plot_predictions(self, title="Predictions per Label", ep=-1, model_dir=None, name="img_training_predictions", empty_zero=True):
|
||||||
"""
|
"""
|
||||||
@param model_dir: Optional. If given, save to model_dir as svg
|
@param model_dir: Optional. If given, save to model_dir as svg
|
||||||
@param ep: Epoch, defaults to last
|
@param ep: Epoch, defaults to last
|
||||||
@ -141,8 +143,23 @@ class EpochTracker:
|
|||||||
|
|
||||||
N = len(self.labels)
|
N = len(self.labels)
|
||||||
label_names = self.labels.get_labels()
|
label_names = self.labels.get_labels()
|
||||||
|
# print(label_names)
|
||||||
|
replace = {
|
||||||
|
"cloth": "fabric",
|
||||||
|
"foam": "foam_PDMS_pure",
|
||||||
|
"foil": "bubble_wrap",
|
||||||
|
"rigid_foam": "foam_PE",
|
||||||
|
"fabric_PP": "fabric",
|
||||||
|
"foam_PDMS_white": "foam_PDMS_pure",
|
||||||
|
"foam_PDMS_black": "foam_PEDOT",
|
||||||
|
"bubble_wrap_PE": "bubble_wrap",
|
||||||
|
}
|
||||||
|
label_names = [ replace[label] if label in replace else label for label in label_names ]
|
||||||
|
|
||||||
fig, ax = plt.subplots(layout="tight")
|
if len(label_names) > 6:
|
||||||
|
fig, ax = plt.subplots(layout="tight", figsize=(7, 6))
|
||||||
|
else:
|
||||||
|
fig, ax = plt.subplots(layout="tight", figsize=(6, 5))
|
||||||
im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.)
|
im = ax.imshow(normalized_predictions, cmap='Blues') # cmap='BuPu', , norm=colors.PowerNorm(1./2.)
|
||||||
ax.set_xticks(np.arange(N))
|
ax.set_xticks(np.arange(N))
|
||||||
ax.set_yticks(np.arange(N))
|
ax.set_yticks(np.arange(N))
|
||||||
@ -155,14 +172,21 @@ class EpochTracker:
|
|||||||
for i in range(1, N):
|
for i in range(1, N):
|
||||||
ax.axhline(i-0.5, color='black', linewidth=1)
|
ax.axhline(i-0.5, color='black', linewidth=1)
|
||||||
|
|
||||||
|
# for i in range(1, N):
|
||||||
|
# ax.axvline(i-0.5, color='#bbb', linewidth=1)
|
||||||
|
|
||||||
# rotate the x-axis labels for better readability
|
# rotate the x-axis labels for better readability
|
||||||
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
||||||
|
|
||||||
# create annotations
|
# create annotations
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
for j in range(N):
|
for j in range(N):
|
||||||
text = ax.text(j, i, round(normalized_predictions[i, j], 2),
|
val = round(normalized_predictions[i, j], 2)
|
||||||
ha="center", va="center", color="black")
|
if empty_zero and val == 0: continue
|
||||||
|
color = "black"
|
||||||
|
if normalized_predictions[i, j] >= 0.6: color = "white"
|
||||||
|
text = ax.text(j, i, val,
|
||||||
|
ha="center", va="center", color=color)
|
||||||
|
|
||||||
# add colorbar
|
# add colorbar
|
||||||
cbar = ax.figure.colorbar(im, ax=ax)
|
cbar = ax.figure.colorbar(im, ax=ax)
|
||||||
|
@ -10,8 +10,13 @@ import threading
|
|||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
from teng_ml.util.transform import Multiply
|
||||||
|
|
||||||
# groups: date, name, n_object, voltage, distance, index
|
# groups: date, name, n_object, voltage, distance, index
|
||||||
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
||||||
|
# for teng_1
|
||||||
|
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_()(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
||||||
|
# for teng_2
|
||||||
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
||||||
|
|
||||||
class LabelConverter:
|
class LabelConverter:
|
||||||
@ -51,7 +56,7 @@ class Datasample:
|
|||||||
def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
|
def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
|
||||||
self.date = date
|
self.date = date
|
||||||
self.label = label
|
self.label = label
|
||||||
self.n_object = int(n_object)
|
self.n_object = 0 if n_object == "" else int(n_object)
|
||||||
self.voltage = float(voltage)
|
self.voltage = float(voltage)
|
||||||
self.distance = float(distance)
|
self.distance = float(distance)
|
||||||
self.index = int(index)
|
self.index = int(index)
|
||||||
@ -86,6 +91,19 @@ class Dataset:
|
|||||||
"""
|
"""
|
||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
self.data = [] # (data, label)
|
self.data = [] # (data, label)
|
||||||
|
|
||||||
|
# NORMALIZE ALL DATA WITH THE SAME FACTOR
|
||||||
|
# sup = 0
|
||||||
|
# inf = 0
|
||||||
|
# for sample in datasamples:
|
||||||
|
# data = sample.get_data()
|
||||||
|
# max_ = np.max(data[:,2])
|
||||||
|
# min_ = np.min(data[:,2])
|
||||||
|
# if max_ > sup: sup = max_
|
||||||
|
# if min_ < inf: inf = min_
|
||||||
|
# multiplier = 1 / max(sup, abs(inf))
|
||||||
|
# self.transforms.append(Multiply(multiplier))
|
||||||
|
|
||||||
for sample in datasamples:
|
for sample in datasamples:
|
||||||
data = self.apply_transforms(sample.get_data())
|
data = self.apply_transforms(sample.get_data())
|
||||||
if split_function is None:
|
if split_function is None:
|
||||||
@ -128,7 +146,7 @@ def get_datafiles(datadir, labels: LabelConverter, exclude_n_object=None, filter
|
|||||||
label = match.groups()[1]
|
label = match.groups()[1]
|
||||||
if label not in labels: continue
|
if label not in labels: continue
|
||||||
|
|
||||||
sample_n_object = float(match.groups()[2])
|
sample_n_object = 0 if match.groups()[2] == "" else int(match.groups()[2])
|
||||||
if exclude_n_object and exclude_n_object == sample_n_object: continue
|
if exclude_n_object and exclude_n_object == sample_n_object: continue
|
||||||
sample_voltage = float(match.groups()[3])
|
sample_voltage = float(match.groups()[3])
|
||||||
if filter_voltage and filter_voltage != sample_voltage: continue
|
if filter_voltage and filter_voltage != sample_voltage: continue
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
from torch import mul
|
||||||
|
|
||||||
class Normalize:
|
class Normalize:
|
||||||
"""
|
"""
|
||||||
@ -41,6 +42,15 @@ class NormalizeAmplitude:
|
|||||||
return f"NormalizeAmplitude(high={self.high})"
|
return f"NormalizeAmplitude(high={self.high})"
|
||||||
|
|
||||||
|
|
||||||
|
class Multiply:
|
||||||
|
def __init__(self, multiplier):
|
||||||
|
self.multiplier = multiplier
|
||||||
|
def __call__(self, data):
|
||||||
|
return data * self.multiplier
|
||||||
|
def __repr__(self):
|
||||||
|
return f"Multiply(multiplier={self.multiplier})"
|
||||||
|
|
||||||
|
|
||||||
class ConstantInterval:
|
class ConstantInterval:
|
||||||
"""
|
"""
|
||||||
Interpolate the data to have a constant interval / sample rate,
|
Interpolate the data to have a constant interval / sample rate,
|
||||||
|
Loading…
Reference in New Issue
Block a user