improve plots

This commit is contained in:
matthias@arch 2023-08-30 17:35:38 +02:00
parent f281b981ea
commit 77b266929d

View File

@ -10,8 +10,13 @@ import threading
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from teng_ml.util.transform import Multiply
# groups: date, name, n_object, voltage, distance, index # groups: date, name, n_object, voltage, distance, index
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" # re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
# for teng_1
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_()(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
# for teng_2
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
class LabelConverter: class LabelConverter:
@ -51,7 +56,7 @@ class Datasample:
def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False): def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
self.date = date self.date = date
self.label = label self.label = label
self.n_object = int(n_object) self.n_object = 0 if n_object == "" else int(n_object)
self.voltage = float(voltage) self.voltage = float(voltage)
self.distance = float(distance) self.distance = float(distance)
self.index = int(index) self.index = int(index)
@ -86,6 +91,19 @@ class Dataset:
""" """
self.transforms = transforms self.transforms = transforms
self.data = [] # (data, label) self.data = [] # (data, label)
# NORMALIZE ALL DATA WITH THE SAME FACTOR
# sup = 0
# inf = 0
# for sample in datasamples:
# data = sample.get_data()
# max_ = np.max(data[:,2])
# min_ = np.min(data[:,2])
# if max_ > sup: sup = max_
# if min_ < inf: inf = min_
# multiplier = 1 / max(sup, abs(inf))
# self.transforms.append(Multiply(multiplier))
for sample in datasamples: for sample in datasamples:
data = self.apply_transforms(sample.get_data()) data = self.apply_transforms(sample.get_data())
if split_function is None: if split_function is None:
@ -128,7 +146,7 @@ def get_datafiles(datadir, labels: LabelConverter, exclude_n_object=None, filter
label = match.groups()[1] label = match.groups()[1]
if label not in labels: continue if label not in labels: continue
sample_n_object = float(match.groups()[2]) sample_n_object = 0 if match.groups()[2] == "" else int(match.groups()[2])
if exclude_n_object and exclude_n_object == sample_n_object: continue if exclude_n_object and exclude_n_object == sample_n_object: continue
sample_voltage = float(match.groups()[3]) sample_voltage = float(match.groups()[3])
if filter_voltage and filter_voltage != sample_voltage: continue if filter_voltage and filter_voltage != sample_voltage: continue