import numpy as np
import pyqtgraph as pg
from PyQt6.QtWidgets import QWidget


import logging
log = logging.getLogger(__name__)

class Plot(pg.GraphicsLayoutWidget):
    """
    pyqtgraph plot widget for showing voltage and LED vs time

    The CPD voltage is drawn in the main view box (left axis); the LED value
    is drawn in a second view box that shares the x axis and drives the right
    axis (fixed y range 0-110).

    To maintain performance, only a fixed number of data points is stored.
    Whenever this buffer size is reached, every second stored point is
    dropped and the sampling stride for new points is doubled, so that all
    stored points keep a constant interval (counted in received samples).

    This means that accuracy is lost when zooming in!

    Parameters
    ----------
    max_data_points
        Capacity of the internal buffers. Must be at least 2.
    """
    def __init__(self, max_data_points=1000):
        if max_data_points < 2:
            # Decimation keeps ceil(n/2) points; with fewer than 2 slots the
            # buffer would never shrink and the next write would overflow it.
            raise ValueError(
                f"max_data_points must be at least 2, got {max_data_points}"
            )
        super().__init__()

        self.setWindowTitle("CPD - LED")

        # Main plot: CPD voltage on the left axis, time on the bottom axis.
        self.v_plot_item = pg.PlotItem()
        self.v_box = self.v_plot_item.vb
        self.addItem(self.v_plot_item, row=1, col=1)

        self.v_plot_item.showAxis('right')
        self.v_plot_item.setLabel('right', "LED [%]")
        self.v_plot_item.setLabel('left', "CPD [V]")
        self.v_plot_item.setLabel('bottom', "time [s]")

        # Secondary view box for the LED curve: x-linked to the main plot and
        # connected to the right axis. It is not managed by the layout, so its
        # geometry is synced manually in update_views().
        self.l_box = pg.ViewBox()
        self.v_plot_item.scene().addItem(self.l_box)
        self.v_plot_item.getAxis('right').linkToView(self.l_box)
        self.l_box.setXLink(self.v_plot_item)
        self.l_box.setYRange(0, 110)  # 0-100 % plus some margin

        self.setBackground("w")

        self.MAX_DATA_POINTS = max_data_points
        # Fixed-size sample buffers; only the first n_data_array entries are valid.
        self.data_t = None
        self.data_v = None
        self.data_l = None
        # Number of valid entries in the buffers.
        self.n_data_array = None
        # Number of samples passed to update_plot() since the last clear.
        self.n_data_total = None
        # Current stride: a sample is stored only if its running index is a
        # multiple of this. Invariant: buffer index i holds the sample with
        # running index i * n_data_take_every_nth.
        self.n_data_take_every_nth = None

        # PlotCurveItem does not draw symbols (markers are a pg.PlotDataItem
        # feature), so both curves are plain lines.
        self.v_line = pg.PlotCurveItem([], [], pen=pg.mkPen("b", width=2))
        self.l_line = pg.PlotCurveItem([], [], pen=pg.mkPen("r", width=2))
        self.l_box.addItem(self.l_line)
        self.v_box.addItem(self.v_line)

        self.clear_data()

        self.update_views()
        self.v_plot_item.getViewBox().sigResized.connect(self.update_views)

    def update_views(self):
        """
        Make sure the linked view boxes have the correct size
        (called after resizing of the plot)

        The LED view box takes over the scene geometry of the main view box
        and is told that its x-linked view changed.
        """
        main_box = self.v_plot_item.getViewBox()
        self.l_box.setGeometry(main_box.sceneBoundingRect())
        self.l_box.linkedViewChanged(main_box, self.l_box.XAxis)

    def update_plot(self, time, voltage, led):
        """
        Add a new sample to the plot.

        If the buffers are full they are decimated first (see _decimate).
        The sample is then stored only if its running index is a multiple of
        the current stride, which keeps the spacing of stored points constant.

        Parameters
        ----------
        time
            x value of the sample (plotted on the "time [s]" axis)
        voltage
            CPD value, plotted against the left axis
        led
            LED value, plotted against the right axis
        """
        redraw = False
        if self.n_data_array >= self.MAX_DATA_POINTS:
            self._decimate()
            # The buffers were rewritten in place and the curves may still
            # reference slices of them, so refresh even if this sample is
            # skipped below.
            redraw = True

        if self.n_data_total % self.n_data_take_every_nth == 0:
            self.data_t[self.n_data_array] = time
            self.data_v[self.n_data_array] = voltage
            self.data_l[self.n_data_array] = led
            self.n_data_array += 1
            redraw = True
        self.n_data_total += 1

        if redraw:
            self._redraw()

    def _decimate(self):
        """Drop every second stored point and double the sampling stride."""
        n = self.n_data_array
        for buf in (self.data_t, self.data_v, self.data_l):
            # Keep indices 0, 2, 4, ... (ceil(n/2) points). Copy first because
            # source and destination overlap in the same buffer.
            kept = buf[:n:2].copy()
            buf[:kept.size] = kept
        self.n_data_array = (n + 1) // 2
        # Kept points are now twice the old stride apart; new samples must
        # use the same spacing (doubling, not incrementing).
        self.n_data_take_every_nth *= 2
        log.debug(
            "Reached %d data points, taking only 1/%d data points from now on",
            self.MAX_DATA_POINTS,
            self.n_data_take_every_nth,
        )

    def _redraw(self):
        """Push the valid part of the buffers to both curves."""
        n = self.n_data_array
        self.v_line.setData(self.data_t[:n], self.data_v[:n])
        self.l_line.setData(self.data_t[:n], self.data_l[:n])

    def set_data(self, data_t, data_v, data_l):
        """
        Set the data to be plotted, without updating the internal arrays

        The next update_plot() call that stores or decimates a point replaces
        the displayed data with the buffered data again.

        Parameters
        ----------
        data_t
            x values shared by both curves
        data_v
            CPD values (left axis)
        data_l
            LED values (right axis)
        """
        self.v_line.setData(data_t, data_v)
        self.l_line.setData(data_t, data_l)

    def clear_data(self):
        """
        Clear the lines and data, and reset the decimation state
        """
        self.data_t = np.empty(self.MAX_DATA_POINTS, dtype=float)
        self.data_v = np.empty(self.MAX_DATA_POINTS, dtype=float)
        self.data_l = np.empty(self.MAX_DATA_POINTS, dtype=float)
        self.n_data_array = 0
        self.n_data_total = 0
        self.n_data_take_every_nth = 1
        self.v_line.setData([], [])
        self.l_line.setData([], [])