# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import logging
import time
import threading
import sys
import socket
from typing import Any

import ffd_sender
import dual_cam
import config


# Exception type raised by socket operations on timeout. Before Python 3.10
# that is socket.timeout; from 3.10 on socket.timeout is an alias of the
# builtin TimeoutError, so the builtin is used directly.
TimeoutExceptionClass = socket.timeout if sys.version_info < (3, 10) else TimeoutError


class SocketConnection:
    """Blocking TCP client socket with retrying reads and guarded writes.

    The socket is opened in ``__init__`` via :meth:`connect`. Reads and
    writes are serialised by two separate re-entrant locks (``recv_lock``
    and ``send_lock``). I/O errors are logged rather than raised:
    :meth:`send` returns ``True`` only when the data was written and
    :meth:`recv` returns ``None`` when nothing could be read.
    """

    CONNECTION_TYPE = 'socket'
    BUFFER_SIZE = 1024  # bytes requested per socket recv() call
    DEFAULT_TIMEOUT = 1  # value passed to socket.settimeout() when no timeout is given
    FAIL_WAIT = 1  # seconds slept after a failed read or write
    RETRIES = 3  # read attempts made by recv() before giving up
    ENDOFLINE = b'\n'  # not used inside this class

    def __init__(self, ip, port, timeout=None, logger=None):
        """Store the connection parameters and try to connect immediately.

        :param ip: host address passed to ``socket.connect``
        :param port: TCP port passed to ``socket.connect``
        :param timeout: socket timeout; ``DEFAULT_TIMEOUT`` when ``None``
        :param logger: optional parent logger; a child named after the class is used
        """
        if logger:
            self.logger = logger.getChild(self.__class__.__name__)
        else:
            self.logger = logging.getLogger(self.__class__.__name__)
        self.ip = ip
        self.port = port
        self.timeout = self.DEFAULT_TIMEOUT if timeout is None else timeout
        self.recv_lock = threading.RLock()
        self.send_lock = threading.RLock()
        # When truthy, every sent and received payload is logged at info level.
        self.verbose = config.get_settings()['verbose']
        self.connection = None
        self.connect()

    def connect(self):
        """Open a new TCP connection to ``(self.ip, self.port)``.

        On success ``self.connection`` holds the connected socket. On an
        I/O error the error is logged, the half-open socket is closed and
        ``self.connection`` is set to ``None`` so that callers checking
        ``if self.connection`` see the failure.
        """
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(self.timeout)
            sock.connect((self.ip, self.port))
        except IOError as e:
            self.logger.error("Error connecting to %s:%s - %s" % (self.ip, self.port, e))
            if sock is not None:
                # Do not keep (or leak) a socket that never got connected.
                try:
                    sock.close()
                except OSError:
                    pass
            sock = None
        self.connection = sock

    def recv(self):
        """Read the currently available data from the connection.

        Reconnects first when there is no connection. A timeout ends the
        call immediately; any other error closes the connection, waits
        ``FAIL_WAIT`` seconds and retries, up to ``RETRIES`` attempts.

        :return: the received bytes, or ``None`` on timeout, when the peer
            sent nothing, or when all attempts failed
        """
        fails = 0
        while fails < self.RETRIES:
            with self.recv_lock:
                if not self.connection:
                    self.logger.warning("Can't receive - no connection to host")
                    self.connect()
                data = None
                try:
                    if not self.connection:
                        raise OSError('No connection')
                    data = self._read_from_connection()
                    if self.verbose:
                        self.logger.info("RECV: %s" % data)
                except TimeoutExceptionClass:
                    self.logger.warning("Timeout while reading")
                    return data
                except (OSError, EOFError) as e:
                    self.logger.warning("Read error: " + str(e))
                except Exception as e:
                    self.logger.error("Unexpected error while reading: " + str(e))
                else:
                    return data
            # The failure handling runs after recv_lock is released: close()
            # takes send_lock and then recv_lock, so calling it while holding
            # recv_lock could deadlock against a concurrent close()/reset().
            fails += 1
            time.sleep(self.FAIL_WAIT)
            self.close()
        return None

    def send(self, data, raw=False, timeout=None):
        """Write ``data`` to the connection.

        :param data: payload; passed through :meth:`prepare_data` unless ``raw``
        :param raw: when true, ``data`` is sent exactly as given
        :param timeout: accepted for interface compatibility; currently not
            applied, the socket keeps the timeout set in :meth:`connect`
        :return: ``True`` when the data was written, ``None`` otherwise
            (no connection, timeout or any other error; all are logged)
        """
        if not raw:
            data = self.prepare_data(data)
        try:
            with self.send_lock:
                if not self.connection:
                    # Nothing was written, so success must not be reported.
                    self.logger.warning("Can't send - no connection to host")
                    return None
                self._write_to_connection(data)
                if self.verbose:
                    self.logger.info("SEND: %s" % data.strip())
        except TimeoutExceptionClass:
            self.logger.warning("Timeout while writing")
        except OSError as e:
            self.logger.warning("Write error: " + str(e))
        except Exception as e:
            self.logger.error("Unexpected error while writing: " + str(e))
        else:
            return True
        return None

    def reset(self):
        """Close the current socket (best effort) and connect again."""
        with self.send_lock:
            with self.recv_lock:
                if self.connection:
                    try:
                        self.connection.close()
                    except Exception as e:
                        # Best effort: a failing close must not prevent the reconnect.
                        self.logger.warning("Error closing connection: %s" % e)
                self.connect()

    def close(self):
        """Close the socket, if any, and set ``self.connection`` to ``None``."""
        self.logger.info("Closing connection")
        with self.send_lock:
            with self.recv_lock:
                if self.connection:
                    try:
                        self.connection.close()
                        self.logger.info("Connection closed")
                    except Exception as e:
                        # Best effort: the connection is dropped either way.
                        self.logger.warning("Error closing connection: %s" % e)
                    self.connection = None

    def prepare_data(self, data):
        """Hook applied to non-raw payloads before sending; returns ``data`` unchanged here."""
        return data

    def _read_from_connection(self):
        """Read chunks of ``BUFFER_SIZE`` bytes until a short chunk arrives.

        :return: the concatenated bytes, or ``None`` when nothing was read
        :raises TimeoutExceptionClass: when the very first read times out
        """
        chunks = []
        while True:
            started = time.monotonic()
            try:
                chunk = self.connection.recv(self.BUFFER_SIZE)
            except TimeoutExceptionClass:
                if chunks:
                    # A full-sized chunk was followed by silence (e.g. the
                    # reply is an exact multiple of BUFFER_SIZE): hand back
                    # what was read instead of discarding it.
                    break
                raise
            if chunk:
                chunks.append(chunk)
            elif chunk == b'' and time.monotonic() - started < 0.001:
                # An empty read that returned immediately: pause so callers
                # that keep polling do not spin.
                time.sleep(self.FAIL_WAIT)
            if len(chunk) < self.BUFFER_SIZE:
                break
        return b''.join(chunks) if chunks else None

    def _write_to_connection(self, data):
        """Send all of ``data``; on a non-timeout OS error sleep ``FAIL_WAIT`` before re-raising."""
        try:
            self.connection.sendall(data)
        except TimeoutExceptionClass:
            # Listed first so that timeouts skip the sleep below.
            raise
        except OSError:
            time.sleep(self.FAIL_WAIT)
            raise


class FFSocketConnection(SocketConnection):
    """Socket connection that frames each message as ``~<message>\\r\\n``.

    ``send`` and ``recv`` accept extra ``endpoint``/``size`` arguments that
    are ignored here (presumably for parity with a USB connection class;
    verify against the callers).
    """

    DEFAULT_TIMEOUT = 5

    def __init__(self, ip, port, timeout=None, logger=None):
        """Connect through the base class, then reset the endpoint attributes to ``None``."""
        super().__init__(ip, port, timeout, logger)
        # Not used anywhere in this class.
        self.file_transfer_endpoint_out = None
        self.endpoint_in = None
        self.endpoint_out = None

    def format_message(self, raw_message):
        """Return ``raw_message`` prefixed with ``~`` and terminated by CRLF."""
        return b"".join((b"~", raw_message, b"\r\n"))

    def send(self, message, endpoint=None, raw=False, timeout=None):
        """Send ``message``, framing it first unless ``raw`` is set.

        :param message: payload bytes
        :param endpoint: ignored
        :param raw: when true, ``message`` is sent without framing
        :param timeout: divided by 1000 before being forwarded (presumably
            given in milliseconds; TODO confirm with callers)
        :return: the result of the base class ``send``
        """
        payload = message if raw else self.format_message(message)
        forwarded_timeout = timeout / 1000 if timeout else timeout
        # The payload is already final, so the base class is told it is raw.
        return super().send(payload, True, timeout=forwarded_timeout)

    def recv(self, size=None, endpoint=None):
        """Delegate to the base class ``recv``; ``size`` and ``endpoint`` are ignored."""
        return super().recv()


class Camera(dual_cam.Camera):
    """Camera that captures from a single HTTP stream on port 8080 of a host."""

    def search_cameras(self):
        """Reset the capture list and open a capture for ``self.url``."""
        self.captures = []
        self.init_capture(self.url)

    def set_url(self, ip):
        """Remember ``ip``, build the stream URL from it and derive the cloud camera number."""
        self.ip = ip
        self.url = ''.join(('http://', ip, ':8080/?action=stream'))
        self.cloud_camera_number = self.get_number_by_ip()

    def get_number_by_ip(self):
        """Turn ``self.ip`` into an integer by zero-padding each dotted part to 3 characters.

        For example ``"192.168.1.10"`` becomes ``192168001010``. When the
        padded parts do not form a valid integer, ``0`` is returned.
        """
        padded = ''.join(part.zfill(3) for part in self.ip.split("."))
        try:
            return int(padded)
        except ValueError:
            return 0

    def get_camera_number_for_cloud(self):
        """Return the number computed by the last ``set_url`` call."""
        return self.cloud_camera_number


class Sender(ffd_sender.Sender):
    """Sender variant that reaches the printer over TCP port 8899 and uses its network camera."""

    RECONNECT_PAUSE = 10  # number of one-second sleeps before a reconnect attempt

    def __init__(self, parent: Any, usb_info: dict, profile: dict):
        """Take the printer address from ``usb_info["IP"]`` (default ``""``) and initialise the base sender."""
        # Set before the base initialiser runs; presumably that initialiser
        # ends up calling _create_connection, which reads self.ip (verify
        # against ffd_sender).
        self.ip = usb_info.get("IP", "")
        super().__init__(parent, usb_info, profile)

    def _create_connection(self, usb_info, profile):
        """Create the socket connection to ``self.ip`` on port 8899 with the default timeout."""
        self.connection = FFSocketConnection(self.ip, 8899, None, self.logger)

    def _read_monitoring_data(self):
        """Read monitoring data; on a falsy result reconnect once and read again.

        :return: the base class result, or the result of the second read
            when the reconnect succeeded; otherwise the original falsy value
        """
        monitoring = super()._read_monitoring_data()
        if monitoring:
            return monitoring

        self.operational_flag = False
        self.logger.warning("Error on reading monitoring data. Trying to reconnect to the printer.")
        try:
            self.connection.close()
        except Exception as e:
            self.logger.warning("Error on connection close: %s" % str(e))

        # Wait in one-second steps so that a raised stop_flag ends the pause early.
        seconds_left = self.RECONNECT_PAUSE
        while seconds_left and not self.stop_flag:
            time.sleep(1)
            seconds_left -= 1

        try:
            self.connection.connect()
            self._handshake()
            self._connect()
            monitoring = super()._read_monitoring_data()
        except Exception as e:
            self.logger.warning("Error on reconnect attempt: %s" % str(e))
        return monitoring

    def _start_camera(self) -> None:
        """Create the camera on first use and run its main loop in a thread unless one is already alive."""
        if not self.camera:
            self.camera = Camera(False)
            self.camera.set_url(self.ip)
            self.camera.search_cameras()
        previous_thread = self.camera_thread
        if previous_thread:
            if previous_thread.is_alive():
                return
            del self.camera_thread
        self.camera.stop_flag = False
        worker = threading.Thread(target=self.camera.main_loop)
        self.camera_thread = worker
        worker.start()

    def _stop_camera(self) -> None:
        """Raise the camera's stop flag when its thread is alive; the thread is not joined here."""
        camera_thread = self.camera_thread
        if camera_thread and camera_thread.is_alive():
            self.camera.stop_flag = True
