2025-05-09 15:08:19 +02:00

440 lines
16 KiB
Python

import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import pickle
import datetime
import logging
log = logging.getLogger(__name__)
from ..utility.file_io import get_next_filename, sanitize_filename
# Presumably a tag for the storage format of flushed data.
# NOTE(review): not referenced by the code in this file — confirm it is still used elsewhere.
FLUSH_TYPE = "pickle-ndarray"
# Filename prefix of the partial data files that PrsData writes into / loads from a data directory
PARTIAL_PREFIX = "PART_"
# Pickled metadata file inside a data directory. It carries PARTIAL_PREFIX as well, so
# PrsData.load_data_from_dir must recognize it before treating it as a partial data file.
METADATA_FILENAME = PARTIAL_PREFIX + "measurement_metadata.pkl"
class PrsData:
"""
Class managing data and metadata.
Can be initialized from data directly, or from a file or directory path.
Keys:
- dR: delta R, the change in reflectivity. (This is the signal amplitude "R" from the lock-in)
- R: the baseline reflectivity. (DC signal measured using Aux In of the lock-in)
- theta: phase
- <qty>_raw: The raw measurement data (all individual samples)
data is a dictionary with:
- key: wavelength as int
- value: dictionary with:
- key: quantity as string: ["dR_raw", "R_raw", "theta_raw", "dR", ...]
- value: quantity value, or array if "_raw"
"""
def __init__(self, data: dict | None=None, metadata: dict | None=None,
load_data_path: str|None=None,
write_data_path: str|None=None,
write_data_name: str = "PRS",
write_dirname: str | None = None,
add_number_if_dir_exists=True,
):
if type(metadata) == dict:
self.metadata = metadata
else:
self.metadata = {}
self.data = data
if data is None and load_data_path is None:
raise ValueError("Either path or data must be defined.")
if data is not None and load_data_path is not None:
raise ValueError("Either path or data must be defined, but not both.")
if load_data_path is not None: # load from file
if os.path.isdir(load_data_path):
self.data, md = PrsData.load_data_from_dir(load_data_path)
self.metadata |= md
elif os.path.isfile(load_data_path):
if load_data_path.endswith(".csv"):
self.data, md = PrsData.load_data_from_csv(load_data_path)
self.metadata |= md
elif load_data_path.endswith(".pkl"):
self.data, md = PrsData.load_data_from_pkl(load_data_path)
self.metadata |= md
else:
raise NotImplementedError(f"Only .csv and .pkl files are supported")
else:
raise FileNotFoundError(f"Path '{load_data_path}' is neither a file nor a directory.")
self.wavelengths = []
keys = list(self.data.keys())
for key in keys:
# for some reason, the wavelengths can end up as string
try:
wl = int(key)
self.wavelengths.append(wl)
self.data[wl] = self.data[key]
del self.data[key]
except ValueError:
pass
self.wavelengths.sort()
# INIT WRITE MODE
self.dirname = None
self.path = None
self.name = write_data_name
if write_data_path:
self.path = os.path.abspath(os.path.expanduser(write_data_path))
if write_dirname is None:
self.dirname = sanitize_filename(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") + "_" + self.name)
else:
self.dirname = sanitize_filename(write_dirname)
self.dirpath = os.path.join(self.path, self.dirname)
if os.path.exists(self.dirpath):
if not add_number_if_dir_exists:
raise Exception(f"Directory '{self.dirname}' already exists. Provide a different directory or pass `add_number_if_dir_exists=True` to ignore this")
else:
i = 1
dirpath = f"{self.dirpath}-{i}"
while os.path.exists(dirpath):
i += 1
dirpath = f"{self.dirpath}-{i}"
print(f"Directory '{self.dirname}' already exists. Trying '{dirpath}' instead")
self.dirpath = dirpath
self._assert_directory_exists()
def __setitem__(self, key, value):
self.data[key] = value
try:
wl = int(key)
self.wavelengths.append(wl)
self.wavelengths.sort()
except ValueError:
pass
def __getitem__(self, key):
return self.data[key]
def _check_has_wavelength(self, wl):
if wl not in self.data: raise KeyError(f"No data for wavelength '{wl}'")
def get_for_wl(self, wl, key) -> float:
self._check_has_wavelength(wl)
if key in self.data[wl]:
return self.data[wl][key]
elif key == "dR_R":
return self.calculate_dR_R_for_wl(wl)[0]
elif key == "sdR_R":
return self.calculate_dR_R_for_wl(wl)[1]
elif f"{key}_raw" in self.data[wl]:
vals = self.data[wl][f"{key}_raw"]
mean = np.mean(vals)
self.data[wl][key] = mean
return mean
elif key.startswith("s") and f"{key[1:]}_raw" in self.data[wl]:
vals = self.data[wl][f"{key[1:]}_raw"]
err = np.std(vals)
self.data[wl][key] = err
return err
raise KeyError(f"No '{key}' data for wavelength '{wl}'")
def calculate_dR_R_for_wl(self, wl) -> tuple[float, float]:
dR, sdR = self.get_for_wl(wl, "dR"), self.get_for_wl(wl, "sdR")
R, sR = self.get_for_wl(wl, "R"), self.get_for_wl(wl, "sR")
dR_R = dR / R
sdR_R = np.sqrt((sdR / R)**2 + (dR * sR/R**2)**2)
self.data[wl]["dR_R"] = dR_R
self.data[wl]["sdR_R"] = sdR_R
return dR_R, sdR_R
key_names = {
"wl": "Wavelength [nm]",
"dR": "dR [V]",
"R": "R [V]",
"sR": "sigma(R) [V]",
"sdR": "sigma(dR) [V]",
"dR_R": "dR/R",
"sdR_R": "sigma(dR/R)",
"theta": "theta [°]",
"stheta": "sigma(theta) [°]"
}
default_spectrum_columns=["wl", "dR", "sdR", "R", "sR", "dR_R", "sdR_R", "theta", "stheta"]
labels = {
"dR_R": r"$\Delta R/R$",
"dR": r"$\Delta R$ [V]",
"R": r"$R$ [V]",
"theta": r"$\theta$ [°]",
"sdR_R": r"$\sigma_{\Delta R/R}$",
"sdR": r"$\sigma_{\Delta R}$ [V]",
"sR": r"$\sigma_R$ [V]",
"stheta": r"$\sigma_\theta$ [°]",
}
def get_spectrum_data(self, wavelengths=None, keys=None) -> np.ndarray:
"""
Return the spectral data for the specified keys and wavelengths.
:param wavelengths: List of wavelengths, or None to use all wavelengths.
:param keys: List of keys, or None to use dR, R, dR_R, Theta
:return: numpy.ndarray where the first index is (wavelength=0, <keys>...=1...) and the second is the wavelengths.
"""
if keys is None: keys = self.default_spectrum_columns
if wavelengths is None:
wavelengths = self.wavelengths
data = np.empty((len(wavelengths), len(keys)), dtype=float)
# this might be slow but it shouldnt be called often
for j, wl in enumerate(wavelengths):
for i, key in enumerate(keys):
if key == "wl":
data[j][i] = wl
else:
val = self.get_for_wl(wl, key)
data[j][i] = val
return data
def to_csv(self, sep=","):
# self.to_dataframe().to_csv(os.path.join(self.path, self.name + ".csv"), index=False, metadata=True)
return PrsData.get_csv(self.get_spectrum_data(), self.metadata, sep=sep)
def save_csv_at(self, filepath, sep=",", verbose=False):
if verbose: print(f"Writing csv to {filepath}")
log.info(f"Writing csv to {filepath}")
with open(filepath, "w") as file:
file.write(self.to_csv(sep=sep))
def save_csv(self, sep=",", verbose=False):
"""Save the csv inside the data directory"""
filepath = os.path.join(self.path, self.dirname + ".csv")
self.save_csv_at(filepath, sep, verbose)
# FILE IO
def _check_write_mode(self):
if self.dirpath is None:
raise RuntimeError(f"Can not write data because {__class__.__name__} is not in write mode.")
def _assert_directory_exists(self):
if not os.path.isdir(self.dirpath):
os.makedirs(self.dirpath)
def write_partial_file(self, key):
self._check_write_mode()
if key not in self.data:
raise KeyError(f"Invalid key '{key}'")
filename = sanitize_filename(PARTIAL_PREFIX + str(key)) + ".pkl"
self._assert_directory_exists()
filepath = os.path.join(self.dirpath, filename)
log.info(f"Writing data '{key}' to {filepath}")
with open(filepath, "wb") as file:
pickle.dump(self.data[key], file)
def write_full_file(self):
filename = sanitize_filename("full-data") + ".pkl"
self._assert_directory_exists()
filepath = os.path.join(self.dirpath, filename)
log.info(f"Writing data to {filepath}")
with open(filepath, "wb") as file:
pickle.dump((self.data, self.metadata), file)
def write_metadata(self):
f"""
Write the metadata to the disk as '{METADATA_FILENAME}'
"""
filepath = os.path.join(self.dirpath, METADATA_FILENAME)
log.debug(f"Writing metadata to {filepath}")
with open(filepath, "wb") as file:
pickle.dump(self.metadata, file)
# STATIC CONVERTER
@staticmethod
def get_csv(data: np.ndarray, metadata: dict, columns: list, sep=","):
csv = ""
# metadata
md_keys = list(metadata.keys())
md_keys.sort()
for k in md_keys:
v = metadata[k]
if type(v) == dict:
csv += f"# {k}:\n"
for kk, vv in v.items():
csv += f"# {kk}:{vv}\n"
if type(v) == list:
csv += f"# {k}:\n"
for vv in v:
csv += f"# - {vv}\n"
else:
csv += f"# {k}: {v}\n"
# data header
csv += "".join(f"{colname}{sep}" for colname in columns).strip(sep) + "\n"
# data
for i in range(data.shape[0]):
csv += f"{data[i, 0]}"
for j in range(1, data.shape[1]):
csv += f"{sep}{data[i,j]}"
csv += "\n"
return csv.strip("\n")
# STATIC LOADERS
# TODO
@staticmethod
def load_data_from_csv(filepath:str, sep: str=",") -> tuple[dict, dict]:
"""
Loads data from a single csv file.
Lines with this format are interpreted as metadata:
# key: value
Lines with this format are interpreted as data:
index, timestamp [s], CPD [V], LED [%]
Parameters
----------
filepath
Path to the csv file.
sep
csv separator
Returns
-------
data
2D numpy array with shape (n, 4) where n is the number of data points.
metadata
Dictionary with metadata.
"""
metadata = {}
with open(filepath, "r") as f:
# this loop will read the metadata at the beginning and skip also the header row
for line in f:
if line.startswith("#"):
colon = line.find(":")
if colon == -1: # normal comment
continue
key = line[1:colon].strip()
value = line[colon+1:].strip()
metadata[key] = value
else:
break
# here, the generator has only data lines
data = np.loadtxt(f, delimiter=sep)
return data, metadata
@classmethod
def load_data_from_pkl(cls, filepath:str) -> tuple[dict, dict]:
"""
Loads data from a single pkl file.
Parameters
----------
:param filepath Path to the file.
:return
data
2D numpy array with shape (n, 4) where n is the number of data points.
metadata
Dictionary with metadata.
"""
metadata = {}
with open(filepath, "rb") as f:
obj = pickle.load(f)
if isinstance(obj, tuple):
if not len(obj) == 2:
raise ValueError(f"Pickle file is a tuple with length {len(obj)}, however it must be 2: (data, metadata)")
data = obj[0]
metadata = obj[1]
if not isinstance(data, dict):
raise ValueError(f"First object in tuple is not a dictionary but {type(data)}")
elif isinstance(obj, dict):
data = obj
else:
raise ValueError(f"Pickled object must be either dict=data or (dict=data, dict=metadata), but is of type {type(obj)}")
# must be loaded by now
if not isinstance(metadata, dict):
raise ValueError(f"Metadata is not a of type dict")
return data, metadata
@staticmethod
def load_data_from_dir(dirpath:str) -> tuple[dict, dict]:
"""
Combines all data files with the PARTIAL_PREFIX from a directory into a dictionary
:param dirpath Path to the data directory
:return data, metadata
"""
files = os.listdir(dirpath)
files.sort()
data = {}
metadata = {}
for filename in files:
filepath = os.path.join(dirpath, filename)
if filename.startswith(PARTIAL_PREFIX):
log.debug(f"Loading {filename}")
# must be first
if filename == METADATA_FILENAME: # Metadata filename must also start with FLUSH_PREFIX
with open(filepath, "rb") as file:
metadata = pickle.load(file)
elif filename.endswith(".csv"):
raise NotADirectoryError(f"Partial .csv files are not supported: '{filename}'")
elif filename.endswith(".pkl"):
key = filename.strip(PARTIAL_PREFIX).strip(".pkl")
with open(filepath, "rb") as file:
val = pickle.load(file)
data[key] = val
else:
raise NotImplementedError(f"Unknown file extension for file '{filepath}'")
else:
log.info(f"Skipping unknown file: '{filepath}'")
return data, metadata
def plot_raw_for_wl(self, wl, what=["dR", "R", "theta"]):
self._check_has_wavelength(wl)
fig, ax = plt.subplots(len(what)) # no sharex since theta might have less points
ax[-1].set_xlabel("Index")
fig.suptitle(f"Raw data for $\\lambda = {wl}$ nm")
for i, qty in enumerate(what):
ax[i].set_ylabel(PrsData.labels[qty])
ax[i].plot(self.data[wl][f"{qty}_raw"])
fig.tight_layout()
return fig
def plot_spectrum(data: "str | PrsData | pd.DataFrame | np.ndarray", title: str = "", errors=False, what=("dR_R",)):
    """
    Plot recorded spectrum data, one subplot per quantity, against the wavelength.

    Parameters
    ----------
    data : str, PrsData, pd.DataFrame or np.ndarray
        Path accepted by ``PrsData(load_data_path=...)``, a ``PrsData`` instance, or a 2D
        array / DataFrame whose columns are ``PrsData.default_spectrum_columns`` (in that order).
    title : str, optional
        Title for the plot. The default "" means no title.
    errors : bool, optional
        If True, the uncertainty ``s<key>`` of each quantity is plotted right below it.
    what : sequence of str, optional
        Keys of the quantities to plot, e.g. ("dR_R", "theta").

    Returns
    -------
    fig
        Matplotlib figure object.
    """
    if isinstance(data, str):
        # the first positional parameter of PrsData is `data`, so the path must be passed by keyword
        _data = PrsData(load_data_path=data).get_spectrum_data()
    elif isinstance(data, PrsData):
        _data = data.get_spectrum_data()
    else:
        # converts a DataFrame to its values; ndarrays are used as they are
        _data = np.asarray(data)

    n_plots = len(what) * (2 if errors else 1)
    # squeeze=False: always get a 2D axes array, also for a single subplot (the default case)
    fig, axes = plt.subplots(n_plots, 1, sharex=True, squeeze=False)
    axes = axes[:, 0]
    axes[-1].set_xlabel("Wavelength [nm]")
    colors = {
        "dR_R": "red",
        "dR": "green",
        "R": "blue",
        "theta": "magenta",
    }
    columns = PrsData.default_spectrum_columns
    wl_idx = columns.index("wl")
    axes_iter = iter(axes)
    for key in what:
        # the uncertainty uses the same color as its quantity; unknown keys use matplotlib's default color
        for k in ((key, f"s{key}") if errors else (key,)):
            ax = next(axes_iter)
            data_idx = columns.index(k)
            ax.plot(_data[:, wl_idx], _data[:, data_idx], color=colors.get(key))
            ax.set_ylabel(PrsData.labels[k])
    if title:
        fig.suptitle(title)
    fig.tight_layout()
    return fig