teng-ml/teng_ml/util/data_loader.py

155 lines
5.3 KiB
Python
Raw Normal View History

2023-04-27 01:53:47 +02:00
2023-04-28 16:03:31 +02:00
from os import path, listdir
import re
import numpy as np
import pandas as pd
2023-05-10 22:44:14 +02:00
from scipy.sparse import data
import threading
2023-04-27 01:53:47 +02:00
2023-04-28 16:03:31 +02:00
from sklearn.model_selection import train_test_split
2023-04-27 01:53:47 +02:00
2023-04-28 16:03:31 +02:00
# filename pattern: yyyy-mm-dd_label_<voltage>V_<distance>mm<index>.csv
# groups: date, name (label), voltage, distance, index
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+)\.csv"
2023-04-27 01:53:47 +02:00
2023-04-28 16:03:31 +02:00
class LabelConverter:
    """
    Map string class labels to (sorted) indices and one-hot vectors.

    The labels are stored sorted, so the index of a label - and therefore the
    position of the 1.0 in its one-hot vector - does not depend on the order
    in which the labels were passed in.
    """

    def __init__(self, class_labels: list[str]):
        """
        @param class_labels: the class labels; any iterable of strings is accepted.
            A sorted copy is stored, the argument itself is not modified.
        """
        # sorted() copies and accepts any iterable (list.copy() only works on lists)
        self.class_labels = sorted(class_labels)

    def get_one_hot(self, label):
        """
        return one hot float32 vector for given label

        @raises ValueError: if label is not one of the class labels
        """
        vec = np.zeros(len(self.class_labels), dtype=np.float32)
        vec[self.class_labels.index(label)] = 1.0
        return vec

    def __getitem__(self, index):
        """return the label at the given index of the sorted labels"""
        return self.class_labels[index]

    def __contains__(self, value):
        return value in self.class_labels

    def __len__(self):
        return len(self.class_labels)

    def get_labels(self):
        """return a copy of the sorted labels"""
        return self.class_labels.copy()

    def __repr__(self):
        return str(self.class_labels)
2023-04-28 16:03:31 +02:00
class Datasample:
    """
    One recorded measurement file together with the metadata parsed from its filename.

    The CSV contents are read lazily on the first get_data() call, or right away
    when init_data is True.
    """

    def __init__(self, date: str, label: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False, n_object=None):
        """
        @param date: recording date string
        @param label: class label string
        @param voltage: voltage (V), converted to float
        @param distance: distance (mm), converted to float
        @param index: sample index, converted to int
        @param label_vec: encoded label (e.g. from LabelConverter.get_one_hot), stored as-is
        @param datapath: path of the CSV file
        @param init_data: load the CSV right away instead of on the first get_data() call
        @param n_object: optional value stored as-is.
            NOTE(review): it is not captured by the filename pattern; confirm its intended meaning.
        """
        self.date = date
        self.label = label
        # previously assigned from an undefined name, which raised NameError for every sample
        self.n_object = n_object
        self.voltage = float(voltage)
        self.distance = float(distance)
        self.index = int(index)
        self.label_vec = label_vec
        self.datapath = datapath
        self.data = None
        if init_data:
            self._load_data()

    def __repr__(self):
        size = self.data.size if self.data is not None else "Unknown"
        return f"{self.label}-{self.index}: dimension={size}, recorded at {self.date} with U={self.voltage}V, d={self.distance}mm"

    def _load_data(self):
        """read the CSV at datapath (comma separated, first line skipped as header) as float32"""
        self.data = np.loadtxt(self.datapath, skiprows=1, dtype=np.float32, delimiter=",")

    def get_data(self):
        """
        return the data, loading it on first access.
        Expected columns: [[timestamp, idata, vdata]]
        """
        if self.data is None:
            self._load_data()
        return self.data
2023-05-10 22:44:14 +02:00
2023-04-28 16:03:31 +02:00
class Dataset:
    """
    Store the whole dataset as a list of (data, label_vec) tuples.
    Implements __getitem__ and __len__, so it can be used with torch.utils.data.DataLoader.
    """

    def __init__(self, datasamples, transforms=None, split_function=None):
        """
        @param datasamples: iterable of samples providing get_data() and label_vec (e.g. Datasample)
        @param transforms: None, a single callable, or a list/tuple of callables that are applied
            in order to the data (before eventual split)
        @param split_function: (data) -> [data0, data1...] callable that splits the data
        """
        # default is None instead of a shared mutable [] - both mean "no transforms"
        self.transforms = transforms
        self.data = []  # (data, label_vec)
        for sample in datasamples:
            data = self.apply_transforms(sample.get_data())
            if split_function is None:
                self.data.append((data, sample.label_vec))
            else:
                self.data.extend((data_split, sample.label_vec) for data_split in split_function(data))

    def apply_transforms(self, data):
        """apply self.transforms (None, callable, or list/tuple of callables) to data and return the result"""
        if isinstance(self.transforms, (list, tuple)):
            for t in self.transforms:
                data = t(data)
        elif self.transforms is not None:
            data = self.transforms(data)
        return data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
2023-04-28 16:03:31 +02:00
2023-05-10 22:44:14 +02:00
def get_datafiles(datadir, labels: LabelConverter, voltage=None):
    """
    get a list of all datafiles in datadir (not recursive) whose names match re_filename,
    i.e. are in the format: yyyy-mm-dd_label_x.xV_xxxmmN.csv (N = sample index)

    @param datadir: directory to search
    @param labels: only files whose label is contained in labels are returned
    @param voltage: if not None, only files recorded at exactly this voltage are returned
    @returns: list of (filepath, match, label) tuples, ordered by filename
    """
    pattern = re.compile(re_filename)
    datafiles = []
    for file in sorted(listdir(datadir)):
        match = pattern.fullmatch(file)
        if not match:
            continue
        label = match.group(2)
        if label not in labels:
            continue
        # compare against None so that voltage=0 acts as a real filter instead of "no filter"
        if voltage is not None and float(match.group(3)) != voltage:
            continue
        datafiles.append((path.join(datadir, file), match, label))
    return datafiles
2023-04-28 16:03:31 +02:00
2023-05-10 22:44:14 +02:00
def load_datasets(datadir, labels: LabelConverter, transforms=None, split_function=None, voltage=None, train_to_test_ratio=0.7, random_state=None, num_workers=None):
    """
    load all data from datadir whose names match re_filename (yyyy-mm-dd_label_x.xV_xxxmmN.csv)
    and split it into a train and a test Dataset

    @param datadir, labels, voltage: select the files, see get_datafiles
    @param transforms, split_function: passed on to Dataset
    @param train_to_test_ratio: fraction of the files that go into the training set
    @param random_state: passed to train_test_split for a reproducible split
    @param num_workers: if set (> 0), the CSV files are read by this many threads;
        None or 0 loads them sequentially
    @returns: (train_dataset, test_dataset)
    @raises ValueError: if no matching data files are found
    """
    datafiles = get_datafiles(datadir, labels, voltage)
    if not datafiles:
        raise ValueError(f"No matching data files found in '{datadir}'")

    def make_sample(datafile, init_data):
        """build a Datasample from a (filepath, match, label) tuple"""
        file, match, label = datafile
        return Datasample(*match.groups(), labels.get_one_hot(label), file, init_data=init_data)

    if not num_workers:
        datasamples = [make_sample(datafile, False) for datafile in datafiles]
    else:
        from concurrent.futures import ThreadPoolExecutor
        # map() keeps the input order, so the split below is reproducible with random_state
        # (appending in thread completion order was not), and it re-raises worker exceptions
        # instead of silently dropping samples
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            datasamples = list(executor.map(lambda datafile: make_sample(datafile, True), datafiles))
    # TODO do the train_test_split after the Dataset split
    # problem: needs to be after transforms
    train_samples, test_samples = train_test_split(datasamples, train_size=train_to_test_ratio, shuffle=True, random_state=random_state)
    train_dataset = Dataset(train_samples, transforms=transforms, split_function=split_function)
    test_dataset = Dataset(test_samples, transforms=transforms, split_function=split_function)
    return train_dataset, test_dataset