diff --git a/teng_ml/util/data_loader.py b/teng_ml/util/data_loader.py index e25e421..7e8a770 100644 --- a/teng_ml/util/data_loader.py +++ b/teng_ml/util/data_loader.py @@ -9,8 +9,9 @@ import threading from sklearn.model_selection import train_test_split -# groups: date, name, voltage, distance, index -re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" +# groups: date, name, n_object, voltage, distance, index +# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" +re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv" class LabelConverter: def __init__(self, class_labels: list[str]): @@ -40,10 +41,10 @@ class LabelConverter: class Datasample: - def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False): + def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False): self.date = date self.label = label - self.n_object = n_object + self.n_object = int(n_object) self.voltage = float(voltage) self.distance = float(distance) self.index = int(index) @@ -101,9 +102,9 @@ class Dataset: return len(self.data) -def get_datafiles(datadir, labels: LabelConverter, voltage=None): +def get_datafiles(datadir, labels: LabelConverter, exclude_n_object=None, filter_voltage=None): """ - get a list of all matching datafiles from datadir that are in the format: yyyy-mm-dd_label_x.xV_xxxmm.csv + get a list of all matching datafiles from datadir that are in the format: yyyy-mm-dd_label__n_object_x.xV_xxxmm.csv """ datafiles = [] files = listdir(datadir) @@ -115,8 +116,11 @@ def get_datafiles(datadir, labels: LabelConverter, voltage=None): label = match.groups()[1] if label not in labels: continue - sample_voltage = float(match.groups()[2]) - if voltage and voltage != sample_voltage: continue + sample_n_object = float(match.groups()[2]) + if exclude_n_object and exclude_n_object == sample_n_object: continue + sample_voltage = float(match.groups()[3]) + if filter_voltage and filter_voltage != sample_voltage: continue + datafiles.append((datadir + "/" + file, match, label)) return datafiles @@ -145,7 +149,6 @@ def load_datasets(datadir, labels: LabelConverter, transforms=None, split_functi for t in threads: t.join() - # TODO do the train_test_split after the Dataset split # problem: needs to be after transforms train_samples, test_samples = train_test_split(datasamples, train_size=train_to_test_ratio, shuffle=True, random_state=random_state)