teng-ml/teng_ml/util/settings.py

46 lines
1.5 KiB
Python
Raw Permalink Normal View History

2023-05-07 21:40:39 +02:00
from ..util.data_loader import LabelConverter
2023-05-10 22:44:14 +02:00
from ..util.split import DataSplitter
2023-05-07 21:40:39 +02:00
class MLSettings:
    """
    Manage model and training settings for easy saving and loading.

    Groups the model architecture parameters (features, layers, hidden size,
    bidirectionality), the training components (optimizer, scheduler, loss
    function, transforms, data splitter) and the training hyperparameters
    (epochs, batch size, labels) into a single object.
    """

    def __init__(self,
                 num_features=1,
                 num_layers=1,
                 hidden_size=1,
                 bidirectional=True,
                 optimizer=None,
                 scheduler=None,
                 loss_func=None,
                 transforms=None,
                 splitter=None,
                 num_epochs=10,
                 batch_size=5,
                 labels=None,
                 ):
        """
        Store the given settings on the instance.

        Args:
            num_features: number of input features.
            num_layers: number of layers.
            hidden_size: hidden size.
            bidirectional: whether the model is bidirectional.
            optimizer: optimizer object (stored as-is).
            scheduler: scheduler object (stored as-is).
            loss_func: loss function (stored as-is).
            transforms: sequence of transforms; ``None`` (the default) means
                a new empty list for this instance.
            splitter: data splitter; when it is a ``DataSplitter`` its
                ``split_size`` is encoded in ``get_name``.
            num_epochs: number of training epochs.
            batch_size: batch size.
            labels: label converter; ``None`` (the default) means a new
                ``LabelConverter([])`` for this instance.
        """
        self.num_features = num_features
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.num_epochs = num_epochs
        self.bidirectional = bidirectional
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_func = loss_func
        # Build fresh objects per instance. Defaults written directly in the
        # signature are evaluated once and would be shared by every
        # MLSettings created without these arguments.
        self.transforms = [] if transforms is None else transforms
        self.splitter = splitter
        self.batch_size = batch_size
        self.labels = LabelConverter([]) if labels is None else labels

    def get_name(self):
        """
        Return a compact identifier string encoding the main settings.

        Format: ``F<F>L<L>H<H:02>B<B>T<T>S<S:03>E<E:03>`` where

        F = num_features
        L = num_layers
        H = hidden_size (zero-padded to 2 digits)
        B = bidirectional ('1' or '0')
        T = number of transforms
        S = splitter.split_size if splitter is a DataSplitter, else 0
            (zero-padded to 3 digits)
        E = number of epochs (zero-padded to 3 digits)
        """
        # isinstance states the intended type check directly and also
        # accepts DataSplitter subclasses; any other splitter (incl. None)
        # contributes 0.
        if isinstance(self.splitter, DataSplitter):
            split_size = self.splitter.split_size
        else:
            split_size = 0
        bidir_flag = "1" if self.bidirectional else "0"
        return (
            f"F{self.num_features}L{self.num_layers}H{self.hidden_size:02}"
            f"B{bidir_flag}T{len(self.transforms)}"
            f"S{split_size:03}E{self.num_epochs:03}"
        )