2023-04-27 01:53:47 +02:00
|
|
|
|
2023-04-28 16:03:31 +02:00
|
|
|
from os import path, listdir
|
|
|
|
import re
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
2023-05-10 22:44:14 +02:00
|
|
|
from scipy.sparse import data
|
2023-08-10 17:29:09 +02:00
|
|
|
import torch
|
2023-05-10 22:44:14 +02:00
|
|
|
|
|
|
|
import threading
|
2023-04-27 01:53:47 +02:00
|
|
|
|
2023-04-28 16:03:31 +02:00
|
|
|
from sklearn.model_selection import train_test_split
|
2023-04-27 01:53:47 +02:00
|
|
|
|
2023-08-04 13:37:45 +02:00
|
|
|
# groups: date, name, n_object, voltage, distance, index
# re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z_]+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+).csv"
# NOTE: the extension dot is escaped so e.g. "...mm3Xcsv" is no longer accepted
re_filename = r"(\d{4}-\d{2}-\d{2})_([a-zA-Z0-9_]+)_(\d+)_(\d{1,2}(?:\.\d*)?)V_(\d+(?:\.\d*)?)mm(\d+)\.csv"
|
2023-04-27 01:53:47 +02:00
|
|
|
|
2023-04-28 16:03:31 +02:00
|
|
|
class LabelConverter:
    """Map between string class labels, sorted label indices and one-hot vectors."""

    def __init__(self, class_labels: list[str]):
        # copy before sorting so the caller's list is not mutated
        self.class_labels = class_labels.copy()
        self.class_labels.sort()

    def get_one_hot(self, label):
        """return one hot vector for given label"""
        vec = np.zeros(len(self.class_labels), dtype=np.float32)
        vec[self.class_labels.index(label)] = 1.0
        return vec

    def get_label_index(self, one_hot: torch.Tensor):
        """return the label index for a given one hot vector"""
        # fixed docstring: this is the inverse of get_one_hot, not a copy of it
        return int(torch.argmax(one_hot).item())

    def __getitem__(self, index):
        # accept either a one-hot tensor or a plain integer index
        if isinstance(index, torch.Tensor):
            return self.class_labels[self.get_label_index(index)]
        return self.class_labels[index]

    def __contains__(self, value):
        return value in self.class_labels

    def __len__(self):
        return len(self.class_labels)

    def get_labels(self):
        """return a copy of the sorted class label list"""
        return self.class_labels.copy()

    def __repr__(self):
        return str(self.class_labels)
|
2023-04-28 16:03:31 +02:00
|
|
|
|
|
|
|
|
|
|
|
class Datasample:
    """One recorded measurement file together with its filename metadata."""

    def __init__(self, date: str, label: str, n_object: str, voltage: str, distance: str, index: str, label_vec, datapath: str, init_data=False):
        # metadata arrives as strings (regex groups from the filename)
        self.date, self.label = date, label
        self.n_object, self.index = int(n_object), int(index)
        self.voltage, self.distance = float(voltage), float(distance)
        self.label_vec = label_vec
        self.datapath = datapath
        # data stays None until first access unless eager loading was requested
        self.data = None
        if init_data:
            self._load_data()

    def __repr__(self):
        if self.data is None:
            size = "Unknown"
        else:
            size = self.data.size
        return f"{self.label}-{self.index}: dimension={size}, recorded at {self.date} with U={self.voltage}V, d={self.distance}mm"

    def _load_data(self):
        # skip the single csv header row, parse the rest as float32
        self.data = np.loadtxt(self.datapath, skiprows=1, dtype=np.float32, delimiter=",")

    def get_data(self):
        """[[timestamp, idata, vdata]]"""
        if self.data is None:
            self._load_data()
        return self.data
|
|
|
|
|
2023-05-10 22:44:14 +02:00
|
|
|
|
2023-04-28 16:03:31 +02:00
|
|
|
class Dataset:
    """
    Store the whole dataset, compatible with torch.data.Dataloader
    """
    def __init__(self, datasamples, transforms=None, split_function=None):
        """
        @param datasamples: iterable of Datasample-like objects exposing get_data() and label_vec
        @param transforms: single callable or list of callables that are applied to the data (before eventual split)
        @param split_function: (data) -> [data0, data1...] callable that splits the data
        """
        # default changed from the shared mutable `[]` to None; apply_transforms
        # treats both as "no transform", so behavior is unchanged
        self.transforms = transforms
        self.data = []  # (data, label)
        for sample in datasamples:
            data = self.apply_transforms(sample.get_data())
            if split_function is None:
                self.data.append((data, sample.label_vec))
            else:
                try:
                    for data_split in split_function(data):
                        self.data.append((data_split, sample.label_vec))
                except ValueError as e:
                    # chain the original error and point at the offending file
                    raise ValueError(f"Exception occured during splitting of sample '{sample.datapath}': {e}") from e

    def apply_transforms(self, data):
        """Apply the configured transform(s) to data; None means identity."""
        if isinstance(self.transforms, list):
            for t in self.transforms:
                data = t(data)
        elif self.transforms is not None:
            data = self.transforms(data)
        return data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
|
|
|
|
|
2023-04-28 16:03:31 +02:00
|
|
|
|
2023-08-04 13:37:45 +02:00
|
|
|
def get_datafiles(datadir, labels: LabelConverter, exclude_n_object=None, filter_voltage=None):
    """
    get a list of all matching datafiles from datadir that are in the format: yyyy-mm-dd_label_n_object_x.xV_xxxmm.csv

    @param labels: only files whose label is contained in labels are kept
    @param exclude_n_object: if set, skip files recorded with this object count
    @param filter_voltage: if set, only keep files recorded at this voltage
    @return list of (filepath, regex match, label) tuples, sorted by filename
    """
    datafiles = []
    files = listdir(datadir)
    files.sort()
    for file in files:
        match = re.fullmatch(re_filename, file)
        if not match:
            print(f"get_datafiles: dropping non matching file '{file}'")
            continue

        label = match.groups()[1]
        if label not in labels: continue

        # compare against None explicitly so 0 remains a usable filter value
        # (the previous truthiness check silently disabled the filters for 0)
        sample_n_object = float(match.groups()[2])
        if exclude_n_object is not None and exclude_n_object == sample_n_object: continue
        sample_voltage = float(match.groups()[3])
        if filter_voltage is not None and filter_voltage != sample_voltage: continue

        datafiles.append((datadir + "/" + file, match, label))
    return datafiles
|
2023-04-28 16:03:31 +02:00
|
|
|
|
2023-05-10 22:44:14 +02:00
|
|
|
|
2023-08-10 17:29:09 +02:00
|
|
|
def load_datasets(datadir, labels: LabelConverter, transforms=None, split_function=None, exclude_n_object=None, voltage=None, train_to_test_ratio=0.7, random_state=None, num_workers=None):
    """
    load all data from datadir that are in the format: yyyy-mm-dd_label_n_object_x.xV_xxxmm.csv

    @param num_workers: if set, eagerly load the csv files with this many threads;
        otherwise samples load their data lazily on first access
    @return (train_dataset, test_dataset)
    """
    datasamples = []
    if num_workers is None:
        for file, match, label in get_datafiles(datadir, labels, exclude_n_object=exclude_n_object, filter_voltage=voltage):
            datasamples.append(Datasample(*match.groups(), labels.get_one_hot(label), file))
    else:
        files = get_datafiles(datadir, labels, exclude_n_object=exclude_n_object, filter_voltage=voltage)
        # list.pop()/append() are atomic under the GIL, so the worker threads
        # can share `files` and `datasamples` without an explicit lock
        def worker():
            while True:
                try:
                    file, match, label = files.pop()
                except IndexError:
                    # No more files to process
                    return
                datasamples.append(Datasample(*match.groups(), labels.get_one_hot(label), file, init_data=True))
        threads = [threading.Thread(target=worker) for _ in range(num_workers)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    # TODO do the train_test_split after the Dataset split
    # problem: needs to be after transforms
    train_samples, test_samples = train_test_split(datasamples, train_size=train_to_test_ratio, shuffle=True, random_state=random_state)
    train_dataset = Dataset(train_samples, transforms=transforms, split_function=split_function)
    test_dataset = Dataset(test_samples, transforms=transforms, split_function=split_function)
    return train_dataset, test_dataset
|
2023-08-10 17:29:09 +02:00
|
|
|
|
|
|
|
|
|
|
|
def count_data(data_loader, label_converter: LabelConverter, print_summary=False):
    """
    Count sequences and datapoints per label in a data loader.

    @param data_loader: unbatched data loader yielding (data, one_hot_label)
    @param print_summary: if truthy, print a summary; a string value is included in the heading
    @return (n_sequences, labels, len_data): total sequence count, per-label
        sequence counts and per-label datapoint counts
    """
    n_sequences = 0  # count number of sequences
    labels = [ 0 for _ in range(len(label_converter)) ]  # count number of sequences per label
    len_data = [ 0 for _ in range(len(label_converter)) ]  # count number of datapoints per label
    for data, y in data_loader:
        # was `n_sequences = i` from enumerate, which under-counted by one
        n_sequences += 1
        label_i = label_converter.get_label_index(y)
        len_data[label_i] += data.shape[0]
        labels[label_i] += 1
    if print_summary:
        print("=" * 50)
        # parenthesize the conditional: the old `"..." + f"..." if cond else ":"`
        # dropped the "Dataset summary" prefix whenever print_summary was not a str
        print("Dataset summary" + (f" for {print_summary}:" if isinstance(print_summary, str) else ":"))
        print(f"Number of sequences: {n_sequences}")
        for i in range(len(label_converter)):
            print(f"- {label_converter[i]:15}: {labels[i]:3} sequences, {len_data[i]:5} datapoints")

    return n_sequences, labels, len_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|