From 577e47d03f9fc0d7db58bd460fc23c42f1823cf9 Mon Sep 17 00:00:00 2001 From: "Matthias@Dell" Date: Thu, 3 Aug 2023 18:43:40 +0200 Subject: [PATCH] changes for teng2 --- teng_ml/main.py | 16 ++++++++-------- teng_ml/util/data_loader.py | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/teng_ml/main.py b/teng_ml/main.py index f288c91..f2915a6 100644 --- a/teng_ml/main.py +++ b/teng_ml/main.py @@ -41,23 +41,23 @@ def test_interpol(): if __name__ == "__main__": - labels = LabelConverter(["white_foam", "glass", "Kapton", "bubble_wrap", "cloth", "black_foam"]) - models_dir = "/home/matth/Uni/TENG/models" # where to save models, settings and results + labels = LabelConverter(["white_foam", "black_foam", "rigid_foam", "cardboard", "glass", "Kapton", "bubble_wrap", "cloth_ffp2", ]) + models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_1" # where to save models, settings and results if not path.isdir(models_dir): makedirs(models_dir) - data_dir = "/home/matth/Uni/TENG/data" + data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data" # Test with num_layers = [ 3 ] hidden_size = [ 8 ] bidirectional = [ True ] - t_const_int = ConstantInterval(0.01) - t_norm = Normalize(0, 1) - transforms = [[ t_const_int ]] #, [ t_const_int, t_norm ]] + # t_const_int = ConstantInterval(0.01) TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore? + t_norm = Normalize(-1, 1) + transforms = [[ t_const_int, t_norm ]] batch_sizes = [ 64 ] # , 16] - splitters = [ DataSplitter(100) ] - num_epochs = [ 80 ] + splitters = [ DataSplitter(100) ] # TODO: try with 0.5-1second snippets + num_epochs = [ 60 ] # num_layers=1, # hidden_size=1, diff --git a/teng_ml/util/data_loader.py b/teng_ml/util/data_loader.py index 92b3711..e25e421 100644 --- a/teng_ml/util/data_loader.py +++ b/teng_ml/util/data_loader.py @@ -43,6 +43,7 @@ class Datasample: def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False): self.date = date self.label = label + self.n_object = n_object self.voltage = float(voltage) self.distance = float(distance) self.index = int(index)