from os import makedirs, path

import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from ..util.settings import MLSettings
from ..tracker.epoch_tracker import EpochTracker
from ..util.file_io import get_next_digits
from ..util.string import class_str, optimizer_str
from ..util import model_io as mio


def select_device(force_device=None):
    """
    Select the best available device (the caller is responsible for moving
    the model onto it).
    """
    if force_device is not None:
        device = force_device
    else:
        device = torch.device(
            "cuda"
            if torch.cuda.is_available()
            # else "mps"
            # if torch.backends.mps.is_available()
            else "cpu"
        )
    # print(device, torch.cuda.get_device_name(device), torch.cuda.get_device_properties(device))
    return device
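
# Usage sketch (illustrative, not called anywhere in this module):
#   model = model.to(select_device())
# or, forcing a specific device:
#   model = model.to(select_device(torch.device("cpu")))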


def train(model, optimizer, scheduler, loss_func, train_loader: DataLoader, st: MLSettings, print_interval=1, print_continuous=False, training_cancel_points=None) -> EpochTracker:
    """
    Train the model on train_loader and track the results per epoch.

    training_cancel_points: optional list of (epoch, min_accuracy) tuples;
    training is cancelled early if the accuracy is still below min_accuracy
    after the given epoch.
    """
    if training_cancel_points is None:  # avoid a mutable default, since the list is popped below
        training_cancel_points = []
    epoch_tracker = EpochTracker(st.labels)
    epoch_tracker.begin()
    for ep in range(st.num_epochs):
        loss = -1
        for i, (data, lengths, y) in enumerate(train_loader):
            # data = batch, seq, features
            x = data[:,:,[2]].float()  # select voltage data
            # print(f"x({x.shape}, {x.dtype})=...")
            # print(f"y({y.shape}, {y.dtype})=...")
            # pack = torch.nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
            # out = model(pack)  # really slow
            out = model(x, lengths)

            with torch.no_grad():
                predicted = torch.argmax(out, dim=1, keepdim=False)  # -> [ label_indices ]
                correct = torch.argmax(y, dim=1, keepdim=False)  # -> [ label_indices ]
                # print(f"predicted={predicted}, correct={correct}")
                # train_total += y.size(0)
                # train_correct += (predicted == correct).sum().item()
                epoch_tracker.add_prediction(correct, predicted)
                # predicted2 = torch.argmax(out, dim=1, keepdim=True)  # -> [ label_indices ]
                # print(f"correct={correct}, y={y}")
            loss = loss_func(out, correct)
            # loss = loss_func(out, y)

            optimizer.zero_grad()  # clear gradients for next train
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

        # predicted = torch.max(torch.nn.functional.softmax(out), 1)[1]
        epoch_tracker.end_epoch(loss, optimizer.param_groups[0]["lr"])
        if (ep+1) % print_interval == 0:
            end = '\r' if print_continuous else '\n'
            print("Training:", epoch_tracker.get_epoch_summary_str(), end=end)

        # cancel training if the model is not good enough
        if len(training_cancel_points) > 0 and ep+1 == training_cancel_points[0][0]:
            if epoch_tracker.accuracies[-1] < training_cancel_points[0][1]:
                print(f"Training cancelled because the model's accuracy={epoch_tracker.accuracies[-1]:.2f} < {training_cancel_points[0][1]} after {ep+1} epochs.")
                break
            training_cancel_points.pop(0)

        if scheduler is not None:
            scheduler.step()
    print("Training:", epoch_tracker.end())
    return epoch_tracker
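
# Note on the label handling in train(): the targets y appear to be one-hot
# encoded, so torch.argmax recovers the class indices that loss_func and the
# EpochTracker consume. A minimal illustration:
#   y = torch.tensor([[0., 1., 0.],
#                     [1., 0., 0.]])
#   torch.argmax(y, dim=1)  # -> tensor([1, 0])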


def validate(model, test_loader: DataLoader, st: MLSettings) -> EpochTracker:
    epoch_tracker = EpochTracker(st.labels)
    epoch_tracker.begin()
    with torch.no_grad():
        for i, (data, y) in enumerate(test_loader):
            # print(ep, "Test")
            x = data[:,[2]].float()  # select voltage data
            out = model(x)

            predicted = torch.argmax(out, dim=1, keepdim=False)  # -> [ label_indices ]
            if y.dim() == 2:  # batched
                correct = torch.argmax(y, dim=1, keepdim=False)  # -> [ label_indices ]
            else:  # unbatched
                correct = torch.argmax(y, dim=0, keepdim=True)  # -> [ label_index ]

            epoch_tracker.add_prediction(correct, predicted)
    print("Validation:", epoch_tracker.end())
    return epoch_tracker
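
# Shape note for the batched/unbatched distinction in validate(): a batched
# one-hot target has two dimensions, an unbatched one has one, e.g.
#   torch.tensor([[0., 1.], [1., 0.]]).dim()  # -> 2 (batched)
#   torch.tensor([0., 1.]).dim()              # -> 1 (unbatched)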


def train_validate_save(model, optimizer, scheduler, loss_func, train_loader: DataLoader, test_loader: DataLoader, st: MLSettings, models_dir, print_interval=1, print_continuous=False, show_plots=False, training_cancel_points=None):
    # assumes model and data are already on the correct device
    # train_loader.to(device)
    # test_loader.to(device)

    # store optimizer, scheduler and loss_func in the settings
    st.optimizer = optimizer_str(optimizer)
    st.scheduler = class_str(scheduler)
    st.loss_func = class_str(loss_func)

    model_name = st.get_name()

    def add_tab(s):
        return "\t" + str(s).replace("\n", "\n\t")

    print(100 * '=')
    print("model name:", model_name)
    print("model:\n", add_tab(model))
    print(f"loss_func: {st.loss_func}")
    print(f"optimizer: {st.optimizer}")
    print(f"scheduler: {st.scheduler}")

    print(100 * '-')
    training_tracker = train(model, optimizer, scheduler, loss_func, train_loader, st, print_interval=print_interval, print_continuous=print_continuous, training_cancel_points=training_cancel_points)
    # print("Training: Count per label:", training_tracker.get_count_per_label())
    # print("Training: Predictions per label:", training_tracker.get_predictions_per_label())

    print(100 * '-')
    validation_tracker = validate(model, test_loader, st)
    # print("Validation: Count per label:", validation_tracker.get_count_per_label())
    # print("Validation: Predictions per label:", validation_tracker.get_predictions_per_label())

    digits = get_next_digits(f"{model_name}_", models_dir)
    model_dir = f"{models_dir}/{model_name}_{digits}"
    # do not create the directory earlier, since it should not exist if training is interrupted
    if not path.isdir(model_dir):  # should always be true; if not, get_next_digits did not work
        makedirs(model_dir)

    fig, _ = validation_tracker.plot_predictions("Validation: Predictions", model_dir=model_dir, name="img_validation_predictions")
    fig, _ = training_tracker.plot_predictions("Training: Predictions", model_dir=model_dir, name="img_training_predictions")
    fig, _ = training_tracker.plot_training(model_dir=model_dir)

    if show_plots:
        plt.show()
    plt.close('all')

    # save the settings, results and model
    mio.save_settings(model_dir, st)
    mio.save_tracker_validation(model_dir, validation_tracker)
    mio.save_tracker_training(model_dir, training_tracker)
    mio.save_model(model_dir, model)
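
# Example wiring (a sketch; MyModel, the loaders and st are placeholders for
# project-specific objects, only the torch calls are standard):
#   device = select_device()
#   model = MyModel().to(device)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
#   loss_func = torch.nn.CrossEntropyLoss()
#   train_validate_save(model, optimizer, scheduler, loss_func, train_loader,
#                       test_loader, st, "models", training_cancel_points=[(10, 0.3)])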