changes for teng2

This commit is contained in:
Matthias@Dell 2023-08-03 18:43:40 +02:00
parent ddeec83e31
commit 577e47d03f
2 changed files with 9 additions and 8 deletions

View File

@ -41,23 +41,23 @@ def test_interpol():
if __name__ == "__main__": if __name__ == "__main__":
labels = LabelConverter(["white_foam", "glass", "Kapton", "bubble_wrap", "cloth", "black_foam"]) labels = LabelConverter(["white_foam", "black_foam", "rigid_foam", "cardboard", "glass", "Kapton", "bubble_wrap", "cloth_ffp2", ])
models_dir = "/home/matth/Uni/TENG/models" # where to save models, settings and results models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_1" # where to save models, settings and results
if not path.isdir(models_dir): if not path.isdir(models_dir):
makedirs(models_dir) makedirs(models_dir)
data_dir = "/home/matth/Uni/TENG/data" data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data"
# Test with # Test with
num_layers = [ 3 ] num_layers = [ 3 ]
hidden_size = [ 8 ] hidden_size = [ 8 ]
bidirectional = [ True ] bidirectional = [ True ]
t_const_int = ConstantInterval(0.01) # t_const_int = ConstantInterval(0.01) TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore?
t_norm = Normalize(0, 1) t_norm = Normalize(-1, 1)
transforms = [[ t_const_int ]] #, [ t_const_int, t_norm ]] transforms = [[ t_const_int, t_norm ]]
batch_sizes = [ 64 ] # , 16] batch_sizes = [ 64 ] # , 16]
splitters = [ DataSplitter(100) ] splitters = [ DataSplitter(100) ] # TODO: try with 0.5-1second snippets
num_epochs = [ 80 ] num_epochs = [ 60 ]
# num_layers=1, # num_layers=1,
# hidden_size=1, # hidden_size=1,

View File

@ -43,6 +43,7 @@ class Datasample:
def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False): def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
self.date = date self.date = date
self.label = label self.label = label
self.n_object = n_object
self.voltage = float(voltage) self.voltage = float(voltage)
self.distance = float(distance) self.distance = float(distance)
self.index = int(index) self.index = int(index)