changes for teng2

This commit is contained in:
Matthias@Dell 2023-08-03 18:43:40 +02:00
parent ddeec83e31
commit 577e47d03f
2 changed files with 9 additions and 8 deletions

View File

@ -41,23 +41,23 @@ def test_interpol():
if __name__ == "__main__":
labels = LabelConverter(["white_foam", "glass", "Kapton", "bubble_wrap", "cloth", "black_foam"])
models_dir = "/home/matth/Uni/TENG/models" # where to save models, settings and results
labels = LabelConverter(["white_foam", "black_foam", "rigid_foam", "cardboard", "glass", "Kapton", "bubble_wrap", "cloth_ffp2", ])
models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_1" # where to save models, settings and results
if not path.isdir(models_dir):
makedirs(models_dir)
data_dir = "/home/matth/Uni/TENG/data"
data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data"
# Test with
num_layers = [ 3 ]
hidden_size = [ 8 ]
bidirectional = [ True ]
t_const_int = ConstantInterval(0.01)
t_norm = Normalize(0, 1)
transforms = [[ t_const_int ]] #, [ t_const_int, t_norm ]]
# t_const_int = ConstantInterval(0.01) TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore?
t_norm = Normalize(-1, 1)
transforms = [[ t_const_int, t_norm ]]
batch_sizes = [ 64 ] # , 16]
splitters = [ DataSplitter(100) ]
num_epochs = [ 80 ]
splitters = [ DataSplitter(100) ] # TODO: try with 0.5-1second snippets
num_epochs = [ 60 ]
# num_layers=1,
# hidden_size=1,

View File

@ -43,6 +43,7 @@ class Datasample:
def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
self.date = date
self.label = label
self.n_object = n_object
self.voltage = float(voltage)
self.distance = float(distance)
self.index = int(index)