# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import re
import time
import base64
import logging
import threading
import collections


import config


class BaseSender:
    """Common state and helpers shared by all printer sender implementations.

    Tracks temperatures, position, progress and state flags, manages the list
    of printer-response callbacks and computes the remaining print time.
    Subclasses must implement set_total_gcodes, load_gcodes and
    unbuffered_gcodes (the base versions raise NotImplementedError).
    """

    def __init__(self, parent, usb_info, profile):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)
        self.logger.propagate = False
        self.stop_flag = False
        self.parent = parent
        self.profile = profile
        self.usb_info = usb_info
        # A subclass may already have set temps before calling this
        # constructor; only fill in defaults when it has not.
        if not getattr(self, "temps", False):
            self.temps = [0, 0]
            self.target_temps = [0, 0]
        self.position = [0, 0, 0, 0]  # X, Y, Z, E
        self.total_gcodes = None
        self.buffer = collections.deque()
        self.current_line_number = 0
        self.percent = 0
        self.operational_flag = False
        self.printing_flag = False
        self.pause_flag = False
        self.heating = False  # TODO rename to heating_flag after merge
        self.callbacks_lock = threading.Lock()
        # Functions called on each printer response. If not empty it will
        # suppress temperatures requesting.
        self.response_callbacks = []
        self.responses_planned = 0
        self.responses = []
        self.filename = None
        self.est_print_time = 0  # in seconds
        self.print_time_left = 0  # in seconds
        self.average_printing_speed = 0  # percents per second
        self.override_clouds_estimations = config.get_settings().get('print_estimation', {}).get('override_clouds', False)
        self.allow_increase_of_print_time_left = config.get_settings().get('print_estimation', {}).get('allow_rise_time_left', False)
        self.is_base64_re = re.compile(b"^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)$")

    def set_total_gcodes(self, length):
        raise NotImplementedError

    def load_gcodes(self, gcodes):
        raise NotImplementedError

    def unbuffered_gcodes(self, gcodes):
        raise NotImplementedError

    def cancel(self):
        """Report that cancelling is unsupported and return False."""
        self.parent.register_error(605, "Cancel is not supported for this printer type", is_blocking=False)
        return False

    def preprocess_gcodes(self, gcodes):
        """Split raw gcode bytes into a list of decoded, non-empty lines.

        Carriage returns are dropped, the data is split on newlines, empty
        lines are discarded, each line is decoded as UTF-8, and whitespace-only
        garbage lines at the tail are removed.

        :param gcodes: raw gcode file contents as bytes.
        :returns: list of str lines (possibly empty).
        """
        gcodes = gcodes.replace(b"\r", b"")
        gcodes = [item.decode("utf-8") for item in gcodes.split(b"\n") if item]
        # Check emptiness first: once every line is popped, gcodes[-1]
        # would raise IndexError.
        while gcodes and gcodes[-1] in ("\n", "\t", " ", "", None):
            line = gcodes.pop()
            self.logger.info("Removing corrupted line '%s' from gcodes tail" % line)
        self.logger.info('Got %d gcodes to print.' % len(gcodes))
        return gcodes

    def gcodes(self, gcodes):
        """Dispatch incoming gcode bytes to the matching loader.

        Base64-encoded payloads are decoded and passed to unbuffered_gcodes;
        anything else goes to load_gcodes.
        """
        start_time = time.monotonic()
        self.logger.debug("Determining gcode format")
        if self.is_base64_re.match(gcodes):
            self.unbuffered_gcodes(base64.b64decode(gcodes))
        else:
            self.logger.debug("Start loading gcodes. Determination time:" + str(time.monotonic() - start_time))
            # Only an explicit False (or a value equal to it) counts as a
            # failure; a None return is treated as success.
            if self.load_gcodes(gcodes) != False:
                self.print_start_time = time.monotonic()
            self.logger.debug("Done loading gcodes. Time:" + str(time.monotonic() - start_time))

    def set_filename(self, filename):
        self.filename = str(filename) if filename else None

    def set_estimated_print_time(self, est_print_time):
        """Store the print duration estimate (seconds); falls back to 0 if it is not integer-convertible."""
        estimation = 0
        try:
            estimation = int(est_print_time)
        except (TypeError, ValueError):
            self.logger.warning(f"Given estimated time can not be converted to integer: {est_print_time}")
        self.est_print_time = estimation
        self.print_time_left = estimation
        self.logger.info(f"Setting estimated print duration to {estimation} seconds")

    def set_average_printing_speed(self, speed):
        """Store the measured average printing speed in percents per second."""
        self.logger.info(f"Average printing speed: {speed}")
        self.average_printing_speed = speed

    def get_remaining_print_time(self, ignore_state=False):
        """Return the estimated remaining print time in whole seconds.

        The estimate is the externally supplied total (est_print_time) scaled
        by progress. The locally measured average speed is used instead when
        that gives nothing, or always when override_clouds is enabled. Unless
        allow_rise_time_left is enabled, the result may not exceed the
        previously stored non-zero value.

        :param ignore_state: compute even when not printing.
        :returns: non-negative int number of seconds (0 when not printing).
        """
        time_left = 0
        if self.is_printing() or ignore_state:
            percent = self.get_percent()
            if self.est_print_time:
                time_left = int(self.est_print_time - self.est_print_time * percent / 100)
            speed = self.average_printing_speed
            # A non-positive speed (progress stalled or went backwards) cannot
            # give a meaningful estimate.
            if speed and speed > 0 and (not time_left or self.override_clouds_estimations):
                time_left = int((100 - percent) / speed)
            # Progress above 100% must not yield a negative time left.
            time_left = max(time_left, 0)
        if self.print_time_left and time_left > self.print_time_left and not self.allow_increase_of_print_time_left:
            time_left = self.print_time_left
        self.logger.info(f"Time left:{time_left}")
        self.print_time_left = time_left
        return time_left

    def get_remaining_print_time_string(self, seconds=None):
        """Format a remaining time as e.g. "2 hours 5 minutes".

        :param seconds: time to format; when None the current remaining print
            time is computed via get_remaining_print_time().
        :returns: human-readable string, "less than a minute" for positive
            values under 60 seconds, or "" for zero/negative values.
        """
        time_string = ""
        if seconds is None:
            seconds = self.get_remaining_print_time()
        if seconds > 0:
            # int() keeps float inputs from rendering as "1.0 hour".
            hours, minutes = divmod(int(seconds) // 60, 60)
            if hours:
                time_string += f"{hours} hour"
                if hours != 1:
                    time_string += "s"
            if minutes:
                if time_string:
                    time_string += " "
                time_string += f"{minutes} minute"
                if minutes != 1:
                    time_string += "s"
            if not hours and not minutes:
                time_string = "less than a minute"
            self.logger.debug("Remaining print time " + time_string)
        return time_string

    def get_position(self):
        return self.position

    def get_temps(self):
        return self.temps

    def get_target_temps(self):
        return self.target_temps

    def get_percent(self):
        return self.percent

    def pause(self):
        self.pause_flag = True

    def unpause(self):
        self.pause_flag = False

    def is_printing(self):
        return self.printing_flag

    def is_paused(self):
        return self.pause_flag

    def is_operational(self):
        return False

    def is_heating(self):
        return self.heating

    def get_downloading_percent(self):
        return self.parent.downloader.get_percent()

    def get_nonstandart_data(self):
        return {}

    def execute_callback(self, line, success):
        """Call every registered response callback with (line, success).

        Exceptions raised by a callback are logged and do not stop the others.
        """
        # Snapshot under the lock, then call outside it. Callbacks may add or
        # remove callbacks (the lock is not reentrant), and removing from the
        # live list mid-iteration would skip the following callback.
        with self.callbacks_lock:
            callbacks = list(self.response_callbacks)
        for callback in callbacks:
            try:
                callback(line, success)
            except Exception:
                self.logger.exception("Exception in callback(%s):" % str(callback))

    def add_response_callback(self, callback_function):
        """Register a response callback (duplicates are allowed)."""
        self.logger.info("Adding callback: %s" % callback_function)
        with self.callbacks_lock:
            self.response_callbacks.append(callback_function)
            self.logger.info("Callback added: %s" % callback_function)

    def del_response_callback(self, callback_function):
        """Unregister one occurrence of a callback; raises ValueError if it is not registered."""
        self.logger.info("Removing callback: %s" % callback_function)
        with self.callbacks_lock:
            self.response_callbacks.remove(callback_function)
            self.logger.info("Callback removed: %s" % callback_function)

    def flush_response_callbacks(self):
        """Unregister all response callbacks."""
        with self.callbacks_lock:
            # Removing while iterating the same list skipped every other
            # callback; take a snapshot and clear in place instead.
            removed = list(self.response_callbacks)
            self.response_callbacks.clear()
            for callback in removed:
                self.logger.info("Callback removed: %s" % callback)

    def init_speed_calculation_thread(self):
        """Start a SpeedCalculationThread if 'print_estimation.by_print_speed' is enabled in settings."""
        if config.get_settings().get('print_estimation', {}).get('by_print_speed'):
            self.logger.info("Starting print speed calculation thread")
            self.speed_calculation_thread = SpeedCalculationThread(self)
            self.speed_calculation_thread.start()
        else:
            self.logger.info("Print speed calculation is disabled. No thread start")

    def close(self):
        """Set stop_flag and wait up to one loop period for the speed thread, if any."""
        self.stop_flag = True
        if hasattr(self, 'speed_calculation_thread'):
            self.logger.info("Joining estimation thread...")
            self.speed_calculation_thread.join(self.speed_calculation_thread.LOOP_TIME)
            self.logger.info("...estimation thread joined")


class SpeedCalculationThread(threading.Thread):
    """Background thread that estimates print speed from progress samples.

    After roughly LOOP_STEPS polling iterations spent printing (about
    LOOP_TIME seconds), it measures the progress delta in percents per second,
    keeps the last SPEEDS_QUEUE_LEN samples, and pushes their average into the
    owning sender via set_average_printing_speed.
    """

    LOOP_STEPS = 100  # polling iterations per measurement window
    LOOP_TIME = 6  # seconds per measurement window
    SPEEDS_QUEUE_LEN = 24  # samples required before an average is reported

    def __init__(self, base_sender):
        self.base_sender = base_sender
        self.stop_flag = False
        self.speeds_log = collections.deque(maxlen=self.SPEEDS_QUEUE_LEN)
        self.logger = logging.getLogger(__class__.__name__)
        super().__init__()

    def get_average_speed(self):
        """Return the mean of the logged speeds, or None until SPEEDS_QUEUE_LEN samples exist."""
        if len(self.speeds_log) != self.SPEEDS_QUEUE_LEN:
            return None
        # NOTE could use normalize or other formulas instead of average, to increase accuracy
        return sum(self.speeds_log) / self.SPEEDS_QUEUE_LEN

    def run(self):
        """Poll the sender until either this thread or the sender sets stop_flag."""
        printing_counter = 0
        nonprinting_counter = 0
        sleep = self.LOOP_TIME / self.LOOP_STEPS
        last_time = time.monotonic()
        last_percent = 0.0
        self.logger.info('Entering speed calculation loop')
        while not self.stop_flag and not self.base_sender.stop_flag:
            if self.base_sender.is_operational() and self.base_sender.is_printing():
                printing_counter += 1
            else:
                nonprinting_counter += 1
            # After LOOP_STEPS non-printing iterations, drop the collected
            # samples and reset the sender's average speed.
            # NOTE(review): nonprinting_counter is never reset while printing,
            # so these iterations accumulate rather than needing to be
            # consecutive; confirm this is intended.
            if nonprinting_counter >= self.LOOP_STEPS:
                nonprinting_counter = 0
                if self.speeds_log:
                    self.speeds_log.clear()
                    self.base_sender.set_average_printing_speed(0)
            if printing_counter < self.LOOP_STEPS:
                time.sleep(sleep)
            else:
                printing_counter = 0
                # No sample is taken while paused or heating.
                if self.base_sender.is_printing() and not self.base_sender.is_paused() and \
                        not self.base_sender.is_heating():
                    percent = self.base_sender.get_percent()
                    delta_time = time.monotonic() - last_time
                    if percent and delta_time:
                        speed = (percent - last_percent) / delta_time
                        self.logger.info(f'Print speed: {speed} %/s')
                        self.speeds_log.append(speed)
                        avg_speed = self.get_average_speed()
                        if avg_speed:
                            self.base_sender.set_average_printing_speed(avg_speed)
                        self.logger.info(f"Delta:{delta_time} Speed:{speed} Avg:{avg_speed}")
                # The reference point moves forward every window, sampled or not.
                last_percent = self.base_sender.get_percent()
                last_time = time.monotonic()

#  def test_get_remaining_print_time():
#      import unittest.mock as mock

#      EST_PRINT_TIME = 15310

#      logging.basicConfig()
#      sender = BaseSender(mock.Mock() , {}, {})
#      sender.set_estimated_print_time(EST_PRINT_TIME)
#      sender.printing_flag = True
#      seconds = sender.get_remaining_print_time()
#      #print(sender.get_remaining_print_time_string())
#      assert seconds == EST_PRINT_TIME
#      sender.percent = 50
#      seconds = sender.get_remaining_print_time()
#      #print(sender.get_remaining_print_time_string())
#      assert seconds == EST_PRINT_TIME // 2


#  if __name__ == "__main__":
#      test_get_remaining_print_time()

