m-teng/m_teng/update_funcs.py
2023-06-18 17:38:10 +02:00

124 lines
4.9 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
import torch
from teng_ml.util import model_io as mio
from teng_ml.util.settings import MLSettings
from teng_ml.util.split import DataSplitter
from m_teng.backends.keithley import keithley
def _update_print(i, ival, vval):
    """
    Print the current measurement on a single console line

    @details:
        The line ends with a carriage return instead of a newline, so the next call overwrites it.
        Ten trailing spaces are appended to the formatted values.
    @param i: index of the measurement, formatted as a 5 digit wide integer
    @param ival: current value, formatted with 12 decimals
    @param vval: voltage value, formatted with 5 decimals
    """
    status = f"n = {i:5d}, I = {ival: .12f} A, U = {vval: .5f} V"
    padding = " " * 10
    print(status + padding, end='\r')
class _Monitor:
    """
    Monitor v and i data

    @details:
        Opens an interactive matplotlib figure with two stacked axes (voltage above, current below).
        Every call to `update` appends one sample to the stored history and redraws the figure.
    """

    def __init__(self, max_points_shown=None, use_print=False):
        """
        @param max_points_shown: if truthy, the x-range is limited to the last `max_points_shown` indices once the index exceeds this value
        @param use_print: if True, every sample is additionally printed via `_update_print`
        """
        self.max_points_shown = max_points_shown
        self.use_print = use_print
        # complete history of everything passed to update()
        self.index = []
        self.vdata = []
        self.idata = []

        plt.ion()
        self.fig1, (self.vax, self.iax) = plt.subplots(2, 1, figsize=(8, 5))

        self.vline, = self.vax.plot(self.index, self.vdata, color="g")
        self.vax.set_ylabel("Voltage [V]")
        self.vax.grid(True)

        self.iline, = self.iax.plot(self.index, self.idata, color="m")
        self.iax.set_ylabel("Current [A]")
        self.iax.grid(True)

    def update(self, i, ival, vval):
        """
        Append one sample and redraw both plots

        @param i: index of the sample, used as x value
        @param ival: current value
        @param vval: voltage value
        """
        if self.use_print:
            _update_print(i, ival, vval)
        self.index.append(i)
        self.idata.append(ival)
        self.vdata.append(vval)
        # update data
        self.iline.set_xdata(self.index)
        self.iline.set_ydata(self.idata)
        self.vline.set_xdata(self.index)
        self.vline.set_ydata(self.vdata)
        # recalculate limits and set them for the view
        self.iax.relim()
        self.vax.relim()
        if self.max_points_shown and i > self.max_points_shown:
            # only show the most recent max_points_shown indices
            self.iax.set_xlim(i - self.max_points_shown, i)
            self.vax.set_xlim(i - self.max_points_shown, i)
        self.iax.autoscale_view()
        self.vax.autoscale_view()
        # update plot
        self.fig1.canvas.draw()
        self.fig1.canvas.flush_events()

    def __del__(self):
        """
        Close the figure when the monitor is destroyed
        """
        # __init__ can fail before fig1 is assigned (e.g. plt.subplots raising).
        # __del__ is still called in that case and must not raise on the missing attribute.
        fig = getattr(self, "fig1", None)
        if fig is not None:
            plt.close(fig)
class _ModelPredict:
    """
    Predict the values that are currently being recorded

    @details:
        Load the model and model settings from model dir
        Wait until the number of recorded points exceeds the window length
        (split size of the models DataSplitter, 200 if the splitter is not a DataSplitter)
        Collect the data from the keithley, apply the transforms and predict the label with the model
        Shows the prediction with a bar plot
    """
    # bar colors, one per label
    colors = ["red", "green", "purple", "blue", "orange", "grey", "cyan"]

    def __init__(self, instr, model_dir):
        """
        @param instr: instrument handle, passed on to the keithley backend functions
        @param model_dir: directory from which `mio.load_model` and `mio.load_settings` load the model and its settings
        """
        self.instr = instr
        self.model = mio.load_model(model_dir)
        self.model_settings: MLSettings = mio.load_settings(model_dir)

        splitter = self.model_settings.splitter
        if isinstance(splitter, DataSplitter):
            self.data_length = splitter.split_size
        else:
            # fallback window length when the settings do not use a DataSplitter
            self.data_length = 200

        plt.ion()
        self.fig1, self.ax = plt.subplots(1, 1, figsize=(8, 5))

        # placeholder bars of height 1 until the first prediction is available
        labels = self.model_settings.labels
        self.bar_cont = self.ax.bar(labels.get_labels(), [1] * len(labels))
        self.ax.set_ylabel("Prediction")
        self.ax.grid(True)

    def update(self, i, ival, vval):
        """
        Predict the label of the most recent `data_length` points and redraw the bar plot

        @details:
            The data is read from the instrument buffers, the arguments are not used.
            Returns without predicting while the buffer does not hold more than `data_length` points.
        @param i: unused
        @param ival: unused
        @param vval: unused
        @raises NotImplementedError: if the model settings have num_features != 1
        """
        buffer_size = keithley.get_buffer_size(self.instr, buffer_nr=1)
        if buffer_size <= self.data_length:
            print(f"ModelPredict.update: buffer_size={buffer_size} <= {self.data_length}")
            return
        # check before talking to the instrument again, only voltage-only models are supported
        if self.model_settings.num_features != 1:
            raise NotImplementedError("Cant handle models with num_features != 1 yet")

        window = (buffer_size - self.data_length, buffer_size)
        ibuffer = keithley.collect_buffer_range(self.instr, window, buffer_nr=1)
        vbuffer = keithley.collect_buffer_range(self.instr, window, buffer_nr=2)

        # columns: ibuffer[:,0], ibuffer[:,1], vbuffer[:,1]
        # presumably timestamp, current, voltage - TODO confirm against the keithley backend
        data = np.vstack((ibuffer[:, 0], ibuffer[:, 1], vbuffer[:, 1])).T
        for t in self.model_settings.transforms:
            data = t(data)
        # select voltage data, without timestamps
        data = np.reshape(data[:, 2], (1, -1, 1))  # batch_size, seq, features

        with torch.no_grad():
            x = torch.as_tensor(data, dtype=torch.float32)
            prediction = self.model(x)  # (batch_size, label-predictions)
            # explicit dim: relying on the implicit dimension is deprecated
            prediction = torch.nn.functional.softmax(prediction, dim=1)  # TODO remove when softmax is already applied by model

        # replace the old bars with the new prediction of the single batch entry
        labels = self.model_settings.labels
        self.bar_cont.remove()
        self.bar_cont = self.ax.bar(labels.get_labels(), prediction[0].tolist(), color=_ModelPredict.colors[:len(labels)])
        # update plot
        self.fig1.canvas.draw()
        self.fig1.canvas.flush_events()