60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
from ..tracker.epoch_tracker import EpochTracker
|
|
from ..util.settings import MLSettings
|
|
|
|
import io
|
|
import pickle
|
|
|
|
# from https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory
|
|
class RenameUnpickler(pickle.Unpickler):
    """Unpickler that rewrites ``teng-ml`` to ``teng_ml`` in module paths.

    Only the module path used for class lookup is altered; everything else
    is inherited from :class:`pickle.Unpickler`.
    """

    def find_class(self, module, name):
        """Look up ``name`` in ``module`` after applying the rename.

        ``str.replace`` returns the string unchanged when ``teng-ml`` does
        not occur in it, so module paths without that spelling are passed
        through as they are.
        """
        resolved_module = module.replace("teng-ml", "teng_ml")
        return super().find_class(resolved_module, name)
|
|
def renamed_load(file_obj):
    """Unpickle one object from ``file_obj`` via ``RenameUnpickler``.

    ``file_obj`` is handed directly to the unpickler's constructor.
    Unpickling can execute arbitrary code, so only use this on trusted data.
    """
    unpickler = RenameUnpickler(file_obj)
    return unpickler.load()
|
|
def renamed_loads(pickled_bytes):
    """Unpickle one object from an in-memory bytes payload.

    The payload is wrapped in an ``io.BytesIO`` buffer and passed to
    ``renamed_load``.
    """
    return renamed_load(io.BytesIO(pickled_bytes))
|
|
|
|
|
|
"""
Load and save the model, settings and EpochTrackers from/to disk.
"""
|
|
|
|
def load_tracker_validation(model_dir):
    """Read the validation tracker from ``<model_dir>/tracker_validation.pkl``.

    The file is opened in binary mode and unpickled with ``renamed_load``.
    The result is expected to be an ``EpochTracker``. Unpickling can execute
    arbitrary code, so only load from trusted directories.
    """
    pickle_path = f"{model_dir}/tracker_validation.pkl"
    with open(pickle_path, "rb") as handle:
        return renamed_load(handle)
|
|
|
|
def load_tracker_training(model_dir):
    """Read the training tracker from ``<model_dir>/tracker_training.pkl``.

    The file is opened in binary mode and unpickled with ``renamed_load``.
    The result is expected to be an ``EpochTracker``. Unpickling can execute
    arbitrary code, so only load from trusted directories.
    """
    pickle_path = f"{model_dir}/tracker_training.pkl"
    with open(pickle_path, "rb") as handle:
        return renamed_load(handle)
|
|
|
|
def load_settings(model_dir):
    """Read the settings object from ``<model_dir>/settings.pkl``.

    The file is opened in binary mode and unpickled with ``renamed_load``.
    The result is expected to be an ``MLSettings`` instance. Unpickling can
    execute arbitrary code, so only load from trusted directories.
    """
    pickle_path = f"{model_dir}/settings.pkl"
    with open(pickle_path, "rb") as handle:
        return renamed_load(handle)
|
|
|
|
def load_model(model_dir):
    """Read the pickled model from ``<model_dir>/model.pkl``.

    The file is opened in binary mode and unpickled with ``renamed_load``;
    whatever object was stored is returned as-is. Unpickling can execute
    arbitrary code, so only load from trusted directories.
    """
    pickle_path = f"{model_dir}/model.pkl"
    with open(pickle_path, "rb") as handle:
        return renamed_load(handle)
|
|
|
|
|
|
def save_tracker_validation(model_dir, validation_tracker: EpochTracker):
    """Pickle ``validation_tracker`` to ``<model_dir>/tracker_validation.pkl``.

    The target file is opened with mode ``"wb"``, so an existing file is
    overwritten. Nothing is returned.
    """
    target_path = f"{model_dir}/tracker_validation.pkl"
    with open(target_path, "wb") as handle:
        pickle.Pickler(handle).dump(validation_tracker)
|
|
|
|
def save_tracker_training(model_dir, training_tracker: EpochTracker):
    """Pickle ``training_tracker`` to ``<model_dir>/tracker_training.pkl``.

    The target file is opened with mode ``"wb"``, so an existing file is
    overwritten. Nothing is returned.
    """
    target_path = f"{model_dir}/tracker_training.pkl"
    with open(target_path, "wb") as handle:
        pickle.Pickler(handle).dump(training_tracker)
|
|
|
|
def save_settings(model_dir, st):
    """Pickle the settings object ``st`` to ``<model_dir>/settings.pkl``.

    The target file is opened with mode ``"wb"``, so an existing file is
    overwritten. Nothing is returned.
    """
    target_path = f"{model_dir}/settings.pkl"
    with open(target_path, "wb") as handle:
        pickle.Pickler(handle).dump(st)
|
|
|
|
def save_model(model_dir, model):
    """Pickle ``model`` to ``<model_dir>/model.pkl``.

    The target file is opened with mode ``"wb"``, so an existing file is
    overwritten. Nothing is returned.
    """
    target_path = f"{model_dir}/model.pkl"
    with open(target_path, "wb") as handle:
        pickle.Pickler(handle).dump(model)
|