# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import logging
import platform
import subprocess
import time
import threading


import serial

import config


class SerialConnection:
    """Wrapper around a pyserial port with locking, retries and logging.

    Reads and writes are guarded by two separate locks (one for receiving,
    one for sending), so a reader and a writer may work on the port at the
    same time, while ``reset()`` and ``close()`` take both locks.

    I/O errors in ``connect()``, ``recv()`` and ``send()`` are logged instead
    of raised; callers detect a failure from ``self.port`` being ``None``
    or from the return value (``None`` from ``recv()``, ``False`` from
    ``send()``).
    """

    DEFAULT_READ_TIMEOUT = 0.5
    DEFAULT_WRITE_TIMEOUT = 0.5
    SERIAL_ERROR_WAIT = 0.3  # pause after a failed read/write before returning/retrying
    MAX_LINE_SIZE = 2048  # protection against endless line when baudrate mismatch
    WRITE_RETRIES = 3  # consecutive write attempts without progress that are tolerated
    DTR_RESET_PAUSE = 0.2  # how long DTR is kept asserted by reset()

    def __init__(self, port_name, baudrate, timeout=DEFAULT_READ_TIMEOUT, start_dtr=None, logger=None):
        """Store the connection parameters and immediately try to open the port.

        Args:
            port_name: Device name passed to ``serial.Serial`` (and to ``stty``
                on Linux).
            baudrate: Baudrate passed to ``serial.Serial``.
            timeout: Read timeout in seconds.
            start_dtr: If not ``None``, the DTR line is set to this value right
                after the port has been opened.
            logger: Optional parent logger; a child named after this class is
                used. Without it a logger named ``<class>.<port_name>`` is used.

        A failure to open the port does not raise: it is logged and
        ``self.port`` stays ``None``.
        """
        if logger:
            self.logger = logger.getChild(self.__class__.__name__)
        else:
            self.logger = logging.getLogger(self.__class__.__name__ + "." + str(port_name))
        self.port_name = port_name
        self.baudrate = baudrate
        self.read_timeout = timeout
        self.write_timeout = self.DEFAULT_WRITE_TIMEOUT
        self.start_dtr = start_dtr
        self.port_recv_lock = threading.Lock()
        self.port_send_lock = threading.Lock()
        self.port_creation_wait = float(config.get_settings().get('port_creation_wait', 1.0))
        self.verbose = config.get_settings()['verbose']
        self.port = None
        self.connect()

    def connect(self):
        """Open the serial port and store it in ``self.port``.

        On Linux ``stty -F <port> -hup`` is run first (clears the tty's
        hang-up-on-close flag). When the ``port_parity_hack`` setting is
        enabled, the port is first opened with odd parity, closed, and then
        reopened with no parity.

        Errors while opening are logged as warnings and leave ``self.port``
        unchanged. After a successful open the initial DTR state is applied
        (if requested) and the method sleeps for ``port_creation_wait``
        seconds.
        """
        if platform.system() == "Linux":
            try:
                subprocess.run(["stty", "-F", self.port_name, "-hup"])
            except OSError as e:
                # stty is only a best-effort tweak: a missing or non-executable
                # binary must not prevent the port from being opened.
                self.logger.warning("Can't run stty for serial port %s. Error:%s", self.port_name, e)
        try:
            if config.get_settings().get('port_parity_hack'):
                # write_timeout is passed here as well, so that a stuck write
                # can't block forever while holding the send lock.
                port = serial.Serial(port=self.port_name, baudrate=self.baudrate,
                                     timeout=self.read_timeout, write_timeout=self.write_timeout,
                                     parity=serial.PARITY_ODD)
                port.close()
                port.parity = serial.PARITY_NONE
                port.open()
                self.port = port
            else:
                self.port = serial.Serial(port=self.port_name, baudrate=self.baudrate,
                                          timeout=self.read_timeout, write_timeout=self.write_timeout)
        except serial.SerialException as e:
            self.logger.warning("Can't open serial port %s. Error:%s", self.port_name, e)
        except Exception as e:
            self.logger.warning("Unexpected error while open serial port %s:%s", self.port_name, e)
        else:
            if self.start_dtr is not None:
                try:
                    self.port.setDTR(self.start_dtr)
                except Exception as e:
                    # Setting DTR stays best-effort, but the failure is no longer
                    # hidden and KeyboardInterrupt/SystemExit are not swallowed.
                    self.logger.warning("Can't set DTR for serial port %s. Error:%s", self.port_name, e)
            # %s instead of %d: the baudrate may be given as a string.
            self.logger.info("Opened serial port %s at baudrate %s", self.port_name, self.baudrate)
            time.sleep(self.port_creation_wait)

    def recv(self, size=None):
        """Read from the port.

        Args:
            size: Number of bytes to read. When falsy, a single line of at
                most ``MAX_LINE_SIZE`` bytes is read instead.

        Returns:
            The value returned by the port's ``read``/``readline`` call, or
            ``None`` when there is no open port or the read raised (the error
            is logged and followed by a ``SERIAL_ERROR_WAIT`` pause).
        """
        try:
            with self.port_recv_lock:
                if not self.port:
                    self.logger.warning("Can't perform the read - no connection to serial port")
                    return None
                if size:
                    data = self.port.read(size)
                else:
                    data = self.port.readline(self.MAX_LINE_SIZE)
            if self.verbose and data is not None:
                self.logger.info("RECV: %s", data.decode(errors='ignore'))
        except serial.SerialException as e:
            self.logger.warning("Can't read serial port %s. Error:%s", self.port_name, e)
            time.sleep(self.SERIAL_ERROR_WAIT)
        except Exception as e:
            self.logger.error("Unexpected error while reading serial port %s:%s", self.port_name, e)
            time.sleep(self.SERIAL_ERROR_WAIT)
        else:
            return data
        return None

    def prepare_data(self, data):
        """Return ``data`` stripped of surrounding whitespace and terminated with a newline."""
        return data.strip() + b"\n"

    def send(self, data, raw=False):
        """Write ``data`` to the port, continuing after partial writes.

        Args:
            data: Bytes to send.
            raw: When false, ``data`` is normalised by ``prepare_data()``
                first; when true, it is sent unchanged and not echoed to the
                log in verbose mode.

        Returns:
            ``True`` once all bytes have been written. ``False`` when there is
            no open port or when more than ``WRITE_RETRIES`` consecutive
            attempts wrote nothing.
        """
        if not raw:
            data = self.prepare_data(data)
        bytes_send = 0
        fails = 0
        data_len = len(data)
        with self.port_send_lock:
            while fails <= self.WRITE_RETRIES:
                if not self.port:
                    self.logger.warning("Can't perform the write - no connection to serial port")
                    break
                sent = 0
                try:
                    # Only the part that has not been written yet is sent again.
                    sent = self.port.write(data[bytes_send:])
                    bytes_send += sent
                    if self.verbose and not raw:
                        self.logger.info("SEND: %s", data[:bytes_send].decode(errors='ignore'))
                except serial.SerialException as e:
                    self.logger.warning("Can't write serial port %s. Error:%s", self.port_name, e)
                    time.sleep(self.SERIAL_ERROR_WAIT)
                except Exception as e:
                    self.logger.error("Unexpected error on write to serial port %s. Error: %s", self.port_name, e)
                    time.sleep(self.SERIAL_ERROR_WAIT)
                else:
                    # ">=" rather than "==": tolerate a port reporting more
                    # bytes written than were requested.
                    if bytes_send >= data_len:
                        return True
                # Any progress resets the failure counter; only consecutive
                # attempts that wrote nothing count towards the limit.
                if sent:
                    fails = 0
                else:
                    fails += 1
            self.logger.warning("Error. Only %s bytes out of %s were send.", bytes_send, data_len)
            return False

    def flush_recv(self):
        """Discard the port's pending input, if a port is open."""
        with self.port_recv_lock:
            if self.port:
                self.port.reset_input_buffer()

    def reset(self):
        """Pulse the DTR line: set it, wait ``DTR_RESET_PAUSE`` seconds, clear it.

        Both locks are held for the duration, so no read or write runs
        during the pulse. ``OSError``/``IOError`` are logged, not raised.
        """
        with self.port_send_lock:
            with self.port_recv_lock:
                if self.port:
                    try:
                        self.port.setDTR(1)
                        time.sleep(self.DTR_RESET_PAUSE)
                        self.port.setDTR(0)
                    except (OSError, IOError):
                        self.logger.error('Error while setting DTR for port ' + str(self.port_name))

    def close(self):
        """Close the port (if open) and forget it.

        ``self.port`` is reset to ``None`` even when the underlying
        ``close()`` raises, so later ``recv()``/``send()`` calls report
        "no connection" instead of using a broken handle. The exception
        itself still propagates to the caller.
        """
        with self.port_send_lock:
            with self.port_recv_lock:
                if self.port:
                    try:
                        self.port.close()
                    finally:
                        self.port = None
