add n_object to regex
This commit is contained in:
parent
577e47d03f
commit
ad2e3468f7
@ -9,8 +9,9 @@ import threading
|
|||||||
|
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
# groups: date, name, voltage, distance, index
|
# groups: date, name, n_object, voltage, distance, index
|
||||||
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
||||||
|
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
|
||||||
|
|
||||||
class LabelConverter:
|
class LabelConverter:
|
||||||
def __init__(self, class_labels: list[str]):
|
def __init__(self, class_labels: list[str]):
|
||||||
@ -40,10 +41,10 @@ class LabelConverter:
|
|||||||
|
|
||||||
|
|
||||||
class Datasample:
|
class Datasample:
|
||||||
def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
|
def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
|
||||||
self.date = date
|
self.date = date
|
||||||
self.label = label
|
self.label = label
|
||||||
self.n_object = n_object
|
self.n_object = int(n_object)
|
||||||
self.voltage = float(voltage)
|
self.voltage = float(voltage)
|
||||||
self.distance = float(distance)
|
self.distance = float(distance)
|
||||||
self.index = int(index)
|
self.index = int(index)
|
||||||
@ -101,9 +102,9 @@ class Dataset:
|
|||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
|
|
||||||
def get_datafiles(datadir, labels: LabelConverter, voltage=None):
|
def get_datafiles(datadir, labels: LabelConverter, exclude_n_object=None, filter_voltage=None):
|
||||||
"""
|
"""
|
||||||
get a list of all matching datafiles from datadir that are in the format: yyyy-mm-dd_label_x.xV_xxxmm.csv
|
get a list of all matching datafiles from datadir that are in the format: yyyy-mm-dd_label__n_object_x.xV_xxxmm.csv
|
||||||
"""
|
"""
|
||||||
datafiles = []
|
datafiles = []
|
||||||
files = listdir(datadir)
|
files = listdir(datadir)
|
||||||
@ -115,8 +116,11 @@ def get_datafiles(datadir, labels: LabelConverter, voltage=None):
|
|||||||
label = match.groups()[1]
|
label = match.groups()[1]
|
||||||
if label not in labels: continue
|
if label not in labels: continue
|
||||||
|
|
||||||
sample_voltage = float(match.groups()[2])
|
sample_n_object = float(match.groups()[2])
|
||||||
if voltage and voltage != sample_voltage: continue
|
if exclude_n_object and exclude_n_object == sample_n_object: continue
|
||||||
|
sample_voltage = float(match.groups()[3])
|
||||||
|
if filter_voltage and filter_voltage != sample_voltage: continue
|
||||||
|
|
||||||
datafiles.append((datadir + "/" + file, match, label))
|
datafiles.append((datadir + "/" + file, match, label))
|
||||||
return datafiles
|
return datafiles
|
||||||
|
|
||||||
@ -145,7 +149,6 @@ def load_datasets(datadir, labels: LabelConverter, transforms=None, split_functi
|
|||||||
for t in threads:
|
for t in threads:
|
||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
|
||||||
# TODO do the train_test_split after the Dataset split
|
# TODO do the train_test_split after the Dataset split
|
||||||
# problem: needs to be after transforms
|
# problem: needs to be after transforms
|
||||||
train_samples, test_samples = train_test_split(datasamples, train_size=train_to_test_ratio, shuffle=True, random_state=random_state)
|
train_samples, test_samples = train_test_split(datasamples, train_size=train_to_test_ratio, shuffle=True, random_state=random_state)
|
||||||
|
Loading…
Reference in New Issue
Block a user