changes for teng2
This commit is contained in:
parent
ddeec83e31
commit
577e47d03f
@ -41,23 +41,23 @@ def test_interpol():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
labels = LabelConverter(["white_foam", "glass", "Kapton", "bubble_wrap", "cloth", "black_foam"])
|
labels = LabelConverter(["white_foam", "black_foam", "rigid_foam", "cardboard", "glass", "Kapton", "bubble_wrap", "cloth_ffp2", ])
|
||||||
models_dir = "/home/matth/Uni/TENG/models" # where to save models, settings and results
|
models_dir = "/home/matth/Uni/TENG/teng_2/models_gen_1" # where to save models, settings and results
|
||||||
if not path.isdir(models_dir):
|
if not path.isdir(models_dir):
|
||||||
makedirs(models_dir)
|
makedirs(models_dir)
|
||||||
data_dir = "/home/matth/Uni/TENG/data"
|
data_dir = "/home/matth/Uni/TENG/teng_2/sorted_data"
|
||||||
|
|
||||||
|
|
||||||
# Test with
|
# Test with
|
||||||
num_layers = [ 3 ]
|
num_layers = [ 3 ]
|
||||||
hidden_size = [ 8 ]
|
hidden_size = [ 8 ]
|
||||||
bidirectional = [ True ]
|
bidirectional = [ True ]
|
||||||
t_const_int = ConstantInterval(0.01)
|
# t_const_int = ConstantInterval(0.01) TODO check if needed: data was taken at equal rate, but it isnt perfect -> maybe just ignore?
|
||||||
t_norm = Normalize(0, 1)
|
t_norm = Normalize(-1, 1)
|
||||||
transforms = [[ t_const_int ]] #, [ t_const_int, t_norm ]]
|
transforms = [[ t_const_int, t_norm ]]
|
||||||
batch_sizes = [ 64 ] # , 16]
|
batch_sizes = [ 64 ] # , 16]
|
||||||
splitters = [ DataSplitter(100) ]
|
splitters = [ DataSplitter(100) ] # TODO: try with 0.5-1second snippets
|
||||||
num_epochs = [ 80 ]
|
num_epochs = [ 60 ]
|
||||||
|
|
||||||
# num_layers=1,
|
# num_layers=1,
|
||||||
# hidden_size=1,
|
# hidden_size=1,
|
||||||
|
@ -43,6 +43,7 @@ class Datasample:
|
|||||||
def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
|
def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
|
||||||
self.date = date
|
self.date = date
|
||||||
self.label = label
|
self.label = label
|
||||||
|
self.n_object = n_object
|
||||||
self.voltage = float(voltage)
|
self.voltage = float(voltage)
|
||||||
self.distance = float(distance)
|
self.distance = float(distance)
|
||||||
self.index = int(index)
|
self.index = int(index)
|
||||||
|
Loading…
Reference in New Issue
Block a user