2023-05-10 22:44:14 +02:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
class DataSplitter:
    r"""
    Split a numpy array along its first axis into consecutive chunks of
    ``datapoints_per_split`` rows each.

    Every complete chunk is always returned. If ``data.shape[0]`` is not a
    multiple of ``datapoints_per_split``, the trailing remainder is returned as
    a final, shorter chunk only if it has at least ``drop_if_smaller_than``
    rows; otherwise it is dropped. With the default threshold (equal to
    ``datapoints_per_split``) the remainder is therefore always dropped.
    """

    def __init__(self, datapoints_per_split, drop_if_smaller_than=-1):
        """
        @param datapoints_per_split: number of rows per chunk; must be >= 1.
        @param drop_if_smaller_than: drop the remaining datapoints if the sequence would be
            smaller than this value. -1 means drop_if_smaller_than=datapoints_per_split
        @raises ValueError: if datapoints_per_split < 1 (a zero or negative chunk size
            cannot produce a meaningful split).
        """
        if datapoints_per_split < 1:
            raise ValueError(f"datapoints_per_split must be >= 1, got {datapoints_per_split}")
        self.split_size = datapoints_per_split
        # Attribute name (including its spelling) kept as-is for backward compatibility.
        self.drop_threshhold = datapoints_per_split if drop_if_smaller_than == -1 else drop_if_smaller_than

    def __call__(self, data: np.ndarray):
        """
        Split ``data`` into chunks along axis 0.

        data: [[t, i, v]] (any array with ndim >= 1 is accepted; only axis 0 is split)

        @return: list of array slices (views into ``data``); all but possibly the
            last have exactly ``split_size`` rows.
        @raises ValueError: if no chunk at all would be produced.
        """
        n = data.shape[0]
        # Number of complete chunks, including one that ends exactly at n.
        n_full = n // self.split_size
        ret_data = [data[k * self.split_size:(k + 1) * self.split_size] for k in range(n_full)]

        rest_start = n_full * self.split_size
        rest = n - rest_start
        # Keep the leftover rows only if there are any and they meet the threshold;
        # this never appends an empty chunk, even for thresholds <= 0.
        if 0 < rest and rest >= self.drop_threshhold:
            ret_data.append(data[rest_start:])

        if not ret_data:
            raise ValueError(f"data has only {n} datapoints, but datapoints_per_split is set to {self.split_size}")

        return ret_data

    def __repr__(self):
        # Unchanged output for the default threshold; otherwise include it so that
        # differently-configured splitters are distinguishable.
        if self.drop_threshhold == self.split_size:
            return f"DataSplitter({self.split_size})"
        return f"DataSplitter({self.split_size}, drop_if_smaller_than={self.drop_threshhold})"
|