# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import sys
import time
import logging
import threading

import config
import paths
import usb.core
import usb.util
import usb.backend.libusb1


class PyUSBConnection:
    """Bulk-transfer connection to a USB printer through pyusb's libusb1 backend.

    ``usb_info`` must provide hexadecimal ``'VID'`` and ``'PID'`` strings.
    ``profile`` is a dict of optional settings:

    * ``interface_indexes`` -- key used to pick the interface from the active
      configuration (default ``(0, 0)``); element 0 is also used as the
      interface number when detaching the kernel driver.
    * ``endpoint_out`` / ``endpoint_in`` / ``file_endpoint_out`` -- endpoint
      addresses as hex strings (or ints); autodetected when missing.
    * ``write_packet_length`` / ``read_packet_length`` -- default to the
      endpoint descriptors' ``wMaxPacketSize`` during autodetection.
    * ``driver_detach_needed`` -- detach the kernel driver before use (Linux only).
    * ``read_timeout`` / ``write_timeout`` -- transfer timeouts in ms.

    ``send`` and ``recv`` are each serialized by their own lock; ``close``
    holds both while tearing the device down.
    """

    # TODO: rename to DEFAULT_* after checking all modules that could use them.
    READ_TIMEOUT = 500  # in ms
    WRITE_TIMEOUT = 1000  # in ms
    CONNECT_ATTEMPTS = 5
    CONNECT_RETRY_DELAY = 1  # seconds slept after every connection attempt
    CLAIM_ATTEMPTS = 5

    def __init__(self, usb_info, profile, verbose=False):
        """Read connection settings from ``profile`` and connect to the device.

        :param usb_info: dict with hexadecimal ``'VID'`` and ``'PID'`` strings.
        :param profile: printer profile dict (see class docstring).
        :param verbose: force verbose logging; when False the ``'verbose'``
            value of ``config.get_settings()`` decides.
        :raises RuntimeError: if ``CONNECT_ATTEMPTS`` attempts all fail.
        """
        self.logger = logging.getLogger('app.' + __name__)
        self.usb_info = usb_info
        self.int_vid = int(usb_info['VID'], 16)
        self.int_pid = int(usb_info['PID'], 16)
        self.interface_indexes = profile.get('interface_indexes', (0, 0))
        self.endpoint_out = self._parse_endpoint(profile.get('endpoint_out'))
        self.endpoint_in = self._parse_endpoint(profile.get('endpoint_in'))
        self.file_transfer_endpoint_out = self._parse_endpoint(profile.get('file_endpoint_out'))
        self.write_packet_length = profile.get('write_packet_length', None)
        self.read_packet_length = profile.get('read_packet_length', None)
        self.driver_detach_needed = profile.get('driver_detach_needed', False)
        self.read_timeout = profile.get('read_timeout', self.READ_TIMEOUT)
        self.write_timeout = profile.get('write_timeout', self.WRITE_TIMEOUT)
        self.send_lock = threading.Lock()
        self.recv_lock = threading.Lock()
        # An explicit verbose=True wins; otherwise use the global setting.
        self.verbose = verbose or config.get_settings()['verbose']
        self.dev = None
        for _ in range(self.CONNECT_ATTEMPTS):
            self.dev = self.connect()
            # Pause after every attempt, including a successful one.
            # NOTE(review): the post-success pause may give the device settle
            # time after set_configuration -- confirm before removing it.
            time.sleep(self.CONNECT_RETRY_DELAY)
            if self.dev:
                break
        else:
            raise RuntimeError("Can't connect to printer")

    @staticmethod
    def _parse_endpoint(value):
        """Return an endpoint address from a profile value, or None if unset.

        Accepts hex strings (``'81'`` or ``'0x81'``) and plain ints. An empty
        string counts as unset, so it cannot later be mistaken for an address.
        """
        if value is None or value == '':
            return None
        if isinstance(value, int):
            return value
        return int(value, 16)

    @staticmethod
    def _is_timeout_error(error):
        """Return True if the error text looks like a transfer timeout (errno 110)."""
        text = str(error)
        return "110" in text or "Operation timed out" in text

    def _endpoints_complete(self):
        """Return True when the OUT, IN and file-transfer OUT endpoints are all known."""
        return all(address is not None for address in
                   (self.endpoint_out, self.endpoint_in, self.file_transfer_endpoint_out))

    # TODO check this logic if it even working. interface indexes could be implemented wrong
    def detach_driver_and_set_cfg(self, dev):
        """On Linux, detach an active kernel driver from the interface and claim it.

        Detach, set_configuration and claim are retried up to ``CLAIM_ATTEMPTS``
        times. Element 0 of ``interface_indexes`` is the interface number used.

        :returns: False if a kernel driver is still active afterwards, else True
            (including on non-Linux platforms, where nothing is done).
        """
        if not sys.platform.startswith('linux'):
            return True
        interface_number = self.interface_indexes[0]
        if not dev.is_kernel_driver_active(interface_number):
            self.logger.info('Interface is free. Connecting...')
            return True
        self.logger.info('Interface is kernel active. Detaching...')
        for _ in range(self.CLAIM_ATTEMPTS):
            try:
                dev.detach_kernel_driver(interface_number)
                dev.set_configuration()
                usb.util.claim_interface(dev, interface_number)
            except Exception as e:
                self.logger.warning('Exception while detaching : %s', e)
            else:
                if dev.is_kernel_driver_active(interface_number):
                    self.logger.info('Can\'t detach USB device. Attempting once more...')
                else:
                    self.logger.info('Detached and claimed!')
                    break
        if dev.is_kernel_driver_active(interface_number):
            self.logger.warning('Cannot claim USB device. Aborting.')
            return False
        return True

    def connect(self):
        """Find, configure and validate the device; return it or None on failure.

        Uses the libusb library located by ``paths.get_libusb_path``. Endpoints
        that the profile did not set are autodetected from the interface
        descriptors.
        """
        backend_from_our_directory = usb.backend.libusb1.get_backend(find_library=paths.get_libusb_path)
        dev = usb.core.find(idVendor=self.int_vid, idProduct=self.int_pid, backend=backend_from_our_directory)
        if not dev:
            return None
        if self.driver_detach_needed:
            self.detach_driver_and_set_cfg(dev)
        try:
            dev.set_configuration()
        except usb.core.USBError:
            if self.verbose:
                self.logger.exception("Error connecting to pyusb device")
            # A failing reset must not abort the caller's retry loop.
            try:
                dev.reset()
            except usb.core.USBError as e:
                self.logger.warning('USBError while resetting device: %s', e)
            return None
        configuration = dev.get_active_configuration()
        # pyusb Configuration indexing takes (interface, alternate setting).
        interface = configuration[self.interface_indexes]
        if not self._endpoints_complete() and not self._autodetect_endpoints(interface):
            message = "Can't autodetect endpoint, you need to set them in printer profile."
            self.logger.error(message)
            return None
        self.logger.info('USB endpoints set:\nOUT:%s\tpacket length:%s\nIN:%s\tpacket_length:%s\nfileOUT:%s',
                         hex(self.endpoint_out), self.write_packet_length, hex(self.endpoint_in),
                         self.read_packet_length, hex(self.file_transfer_endpoint_out))
        return dev

    def _autodetect_endpoints(self, interface):
        """Fill in unset endpoints and packet lengths from ``interface``'s descriptors.

        The first OUT endpoint becomes ``endpoint_out``. The next OUT endpoint
        with a different address becomes ``file_transfer_endpoint_out``. The
        first IN endpoint becomes ``endpoint_in``.

        NOTE(review): success requires a second OUT endpoint, so devices with a
        single OUT endpoint must declare ``file_endpoint_out`` in the profile.

        :returns: True once all three endpoints are known, False otherwise.
        """
        for endpoint in interface:
            address = endpoint.bEndpointAddress
            direction = usb.util.endpoint_direction(address)
            if direction == usb.util.ENDPOINT_OUT:
                if not self.write_packet_length:
                    self.write_packet_length = endpoint.wMaxPacketSize
                if self.endpoint_out is None:
                    self.endpoint_out = address
                elif self.file_transfer_endpoint_out is None and address != self.endpoint_out:
                    self.file_transfer_endpoint_out = address
            elif direction == usb.util.ENDPOINT_IN:
                if not self.read_packet_length:
                    self.read_packet_length = endpoint.wMaxPacketSize
                if self.endpoint_in is None:
                    self.endpoint_in = address
            if self._endpoints_complete():
                return True
        return False

    def format_message(self, raw_message):
        """Return ``raw_message`` as text terminated with CRLF."""
        return str(raw_message) + "\r\n"

    def send(self, message, endpoint=None, raw=False, timeout=None):
        """Write ``message`` to the device.

        :param message: text (encoded as UTF-8 before writing), or bytes.
        :param endpoint: endpoint address; falsy means ``endpoint_out``.
        :param raw: if True, skip ``format_message`` (no CRLF appended).
        :param timeout: write timeout in ms; falsy means ``write_timeout``.
        :returns: True on success; None when not connected or the write failed
            (failures are logged, not raised).
        """
        with self.send_lock:
            if not self.dev:
                return None
            if not endpoint:
                endpoint = self.endpoint_out
            if not raw:
                message = self.format_message(message)
            if not timeout:
                timeout = self.write_timeout
            if self.verbose:
                self.logger.debug("SEND to endpoint:%s timeout:%d\n%s", endpoint, timeout, message.strip())
            if isinstance(message, str):
                message = message.encode("utf-8")
            try:
                self.dev.write(endpoint, message, timeout)
            except usb.core.USBError as e:
                if self._is_timeout_error(e):
                    self.logger.info("Write timeout")
                else:
                    self.logger.info('USBError on write: %s', e)
            except Exception as e:
                self.logger.warning('Error while writing data "%s"\nError: %s', message, e, exc_info=True)
            else:
                return True
        return None

    def recv(self, size=None, endpoint=None):
        """Read up to ``size`` bytes from the device and return them as a str.

        Each received byte becomes one character via ``chr``.

        :param size: bytes to request; falsy means ``read_packet_length``.
        :param endpoint: endpoint address; falsy means ``endpoint_in``.
        :returns: the received text; None when not connected or the read
            failed or timed out (failures are logged, not raised).
        """
        with self.recv_lock:
            if not self.dev:
                return None
            if not endpoint:
                endpoint = self.endpoint_in
            if not size:
                size = self.read_packet_length
            try:
                answer_array = self.dev.read(endpoint, size, self.read_timeout)
                answer_str = ''.join(chr(x) for x in answer_array)
                if self.verbose:
                    self.logger.debug("RECV:%s", answer_str)
            except usb.core.USBError as e:
                if self._is_timeout_error(e):
                    if self.verbose:
                        self.logger.info("Read timeout")
                else:
                    self.logger.info('USBError on read: %s', e)
            except Exception as e:
                self.logger.warning('Error while reading gcode: %s', e, exc_info=True)
            else:
                return answer_str
        return None

    def reset(self):
        """Reset the USB device; does nothing if no device is connected."""
        if self.dev:
            self.dev.reset()

    def dispose(self):
        """Release pyusb resources held for the device; does nothing without one."""
        if self.dev:
            usb.util.dispose_resources(self.dev)

    def close(self):
        """Reset the device (best effort) and drop it, holding both I/O locks."""
        self.logger.info("Closing USB device connection...")
        with self.send_lock, self.recv_lock:
            if self.dev:
                try:
                    self.dev.reset()
                except Exception as e:
                    # Best effort: the device may already be gone; log instead
                    # of silently swallowing (and no longer eat KeyboardInterrupt).
                    self.logger.debug("Error resetting USB device on close: %s", e)
                self.dev = None
                self.logger.info("... device closed")

