diff --git a/teng_ml/util/transform.py b/teng_ml/util/transform.py index 96b7f5b..db78cc6 100644 --- a/teng_ml/util/transform.py +++ b/teng_ml/util/transform.py @@ -9,20 +9,37 @@ class Normalize: assert(low < high) self.low = low self.high = high - def __call__(self, a): - min_ = np.min(a) - a = a - min_ - max_ = np.max(a) + def __call__(self, data): + min_ = np.min(data) + data = data - min_ # smallest point is 0 now + max_ = np.max(data) if max_ != 0: - a = (a / max_) + data = (data / max_) # now normalized between 0 and 1 - a *= (self.high - self.low) - a -= self.low - return a + data *= (self.high - self.low) + data += self.low + return data def __repr__(self): return f"Normalize(low={self.low}, high={self.high})" +class NormalizeAmplitude: + """ + scale data so that all values are between -high and high + """ + def __init__(self, high=1): + self.high = high + + def __call__(self, data): + min_ = np.min(data) + max_ = np.max(data) + scale = np.max([np.abs(min_), np.abs(max_)]) + if scale != 0: + data = data / scale * self.high + return data + def __repr__(self): + return f"NormalizeAmplitude(high={self.high})" + class ConstantInterval: """ @@ -32,15 +49,15 @@ class ConstantInterval: def __init__(self, interval): self.interval = interval - def __call__(self, a): + def __call__(self, data): """ array: [timestamps, data1, data2...] """ - timestamps = a[:,0] + timestamps = data[:,0] new_stamps = np.arange(timestamps[0], timestamps[-1], self.interval) ret = new_stamps - for i in range(1, a.shape[1]): # - interp = interp1d(timestamps, a[:,i]) + for i in range(1, data.shape[1]): # + interp = interp1d(timestamps, data[:,i]) new_vals = interp(new_stamps) ret = np.vstack((ret, new_vals)) return ret.T