teng-ml/teng_ml/util/string.py

77 lines
2.2 KiB
Python
Raw Normal View History

2023-05-10 22:44:14 +02:00
import inspect
import torch.optim.lr_scheduler as sd
import re
def fill_and_center(s: str, fill_char="=", length=100):
    """
    Center ``s`` in a line of ``length`` characters padded with ``fill_char``.

    The text is separated from the padding by one space on each side, e.g.
    ``fill_and_center("abc", "-", 11) == "--- abc ---"``. When the free space
    is odd, the extra fill character goes on the right.

    Args:
        s: text to center.
        fill_char: single padding character.
        length: total width of the returned line.

    Returns:
        A line of exactly ``length`` characters, or ``s`` unchanged when it is
        too long to leave at least one fill character on each side.
    """
    margin = (length - len(s)) // 2
    if margin <= 1:
        # not enough room for "<fill> s <fill>": leave the text untouched
        return s
    side = fill_char * (margin - 1)
    line = f"{side} {s} {side}"
    # an odd amount of free space leaves the line one character short;
    # pad with the caller's fill character (previously hard-coded to "=")
    if len(line) < length:
        line += fill_char
    return line
2023-08-10 17:29:09 +02:00
2023-05-10 22:44:14 +02:00
def class_str(x):
    """
    Return a constructor-like string for ``x`` listing its non-default arguments.

    Only instance attributes (entries of ``x.__dict__``) whose names match a
    parameter of ``type(x)``'s constructor signature are considered, and only
    those whose value differs from that parameter's default are listed, in
    ``__dict__`` order. Parameters without a default always count as differing.

    Args:
        x: any object.

    Returns:
        ``"Name(arg1=value1, ...)"`` with values formatted by ``repr``, or just
        ``"Name"`` when nothing differs from the defaults, the object has no
        ``__dict__`` (e.g. ``__slots__`` classes), or the signature cannot be
        inspected.
    """
    name = type(x).__name__
    try:
        signature = inspect.signature(type(x))
    except (ValueError, TypeError):
        # builtins / extension types may not expose an inspectable signature
        return name

    # objects using __slots__ have no instance __dict__
    attributes = getattr(x, "__dict__", {})
    params = []
    for param_name, param_value in attributes.items():
        parameter = signature.parameters.get(param_name)
        if parameter is None:
            continue
        try:
            changed = bool(param_value != parameter.default)
        except (ValueError, RuntimeError, TypeError):
            # array-likes (numpy arrays, torch tensors) cannot be reduced to a
            # single bool; show the value rather than dropping it
            changed = True
        if changed:
            params.append(f"{param_name}={param_value!r}")

    if params:
        return f"{name}({', '.join(params)})"
    return name
def optimizer_str(x):
    """
    Return a constructor-like string for an optimizer listing its non-default hyper-parameters.

    Unlike ``class_str``, the values are read from the instance's ``defaults``
    dict (where torch optimizers keep their hyper-parameters) instead of from
    individual attributes. Only keys matching a parameter of ``type(x)``'s
    constructor signature and differing from that parameter's default are listed.

    Args:
        x: optimizer-like object; presumably a ``torch.optim.Optimizer``.

    Returns:
        ``"Name(lr=0.1, ...)"`` with values formatted by ``repr``, or just
        ``"Name"`` when nothing differs, the instance has no ``defaults`` dict,
        or the signature cannot be inspected.
    """
    name = type(x).__name__
    try:
        signature = inspect.signature(type(x))
    except (ValueError, TypeError):
        # builtins / extension types may not expose an inspectable signature
        return name

    # previously x.__dict__["defaults"], which raised KeyError for objects without it
    hyperparams = getattr(x, "__dict__", {}).get("defaults", {})
    params = []
    for param_name, param_value in hyperparams.items():
        parameter = signature.parameters.get(param_name)
        if parameter is None:
            continue
        try:
            changed = bool(param_value != parameter.default)
        except (ValueError, RuntimeError, TypeError):
            # tensor-valued hyper-parameters cannot be reduced to a single bool;
            # show the value rather than crashing or dropping the rest
            changed = True
        if changed:
            params.append(f"{param_name}={param_value!r}")

    if params:
        return f"{name}({', '.join(params)})"
    return name
2023-08-10 17:29:09 +02:00
2023-05-10 22:44:14 +02:00
def cleanup_str(s):
    """
    Convert ``s`` to a string and, if it describes an LR scheduler, shorten it.

    A scheduler description is expected to look like
    ``Sched(<params>optimizer=Name (<multi-line optimizer repr containing
    "initial_lr: <value>">)<params>)`` (presumably the output of ``class_str``
    on a scheduler -- confirm against callers). The embedded optimizer repr is
    collapsed to ``optimizer=Name(initial_lr: <value>, ...)``; all other
    scheduler parameters are kept.

    Args:
        s: a string or any object (converted with ``str``).

    Returns:
        The shortened single-line scheduler string if ``s`` matches that
        layout, otherwise the (stringified) input unchanged.
    """
    if not isinstance(s, str):
        s = str(s)
    # groups: (scheduler name, params before optimizer, "optimizer=Name",
    #          "initial_lr: value", params after optimizer)
    # the value may use exponent notation (e.g. 1e-05), hence [\d.eE+-]; the old
    # [\d.]+ captured only "1" there. Optimizer names may contain digits/underscores.
    scheduler_pattern = r"(\w+)\((.*)(optimizer=\w+) \(.*(initial_lr: [\d.eE+-]+).*?\)(.*)\)"
    match = re.fullmatch(scheduler_pattern, s.replace("\n", " "))
    if match is None:
        return s
    sched_name, params_before, optimizer, initial_lr, params_after = match.groups()
    return f"{sched_name}({params_before}{optimizer}({initial_lr}, ...){params_after})"