From 77b266929d0d798085fc8cf6b4e8a1ec86c274bf Mon Sep 17 00:00:00 2001 From: "matthias@arch" Date: Wed, 30 Aug 2023 17:35:38 +0200 Subject: [PATCH] improve plots --- teng_ml/util/data_loader.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/teng_ml/util/data_loader.py b/teng_ml/util/data_loader.py index 2c9be7a..4e009f2 100644 --- a/teng_ml/util/data_loader.py +++ b/teng_ml/util/data_loader.py @@ -10,8 +10,13 @@ import threading from sklearn.model_selection import train_test_split +from teng_ml.util.transform import Multiply + # groups: date, name, n_object, voltage, distance, index # re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" +# for teng_1 +# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_()(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" +# for teng_2 re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" class LabelConverter: @@ -51,7 +56,7 @@ class Datasample: def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False): self.date = date self.label = label - self.n_object = int(n_object) + self.n_object = 0 if n_object == "" else int(n_object) self.voltage = float(voltage) self.distance = float(distance) self.index = int(index) @@ -86,6 +91,19 @@ class Dataset: """ self.transforms = transforms self.data = [] # (data, label) + + # NORMALIZE ALL DATA WITH THE SAME FACTOR + # sup = 0 + # inf = 0 + # for sample in datasamples: + # data = sample.get_data() + # max_ = np.max(data[:,2]) + # min_ = np.min(data[:,2]) + # if max_ > sup: sup = max_ + # if min_ < inf: inf = min_ + # multiplier = 1 / max(sup, abs(inf)) + # self.transforms.append(Multiply(multiplier)) + for sample in datasamples: data = self.apply_transforms(sample.get_data()) if split_function is None: @@ -128,7 +146,7 @@ def get_datafiles(datadir, labels: LabelConverter, exclude_n_object=None, filter label = match.groups()[1] if label not in labels: continue - sample_n_object = float(match.groups()[2]) + sample_n_object = 0 if match.groups()[2] == "" else int(match.groups()[2]) if exclude_n_object and exclude_n_object == sample_n_object: continue sample_voltage = float(match.groups()[3]) if filter_voltage and filter_voltage != sample_voltage: continue