# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import copy
import errno
import json
import logging
import socket
import threading
import time
from functools import reduce

import log
import config
import platforms


# UDP port that BroadcastReceiver binds to; probes and id-dict replies are both sent here.
BROADCAST_RECV_PORT = 5168
# Exact payload of the detection probe broadcast by ClientScanner.detect().
BROADCAST_MESSAGE = b"3DPrinterOS_Detect"


class BroadcastReceiver(threading.Thread):
    """Daemon thread listening for LAN client-detection datagrams on BROADCAST_RECV_PORT.

    Two kinds of datagram are handled:
      * the plain BROADCAST_MESSAGE probe - when the ``remote_control.detectable``
        setting is on, ``parent.send_id_dict`` is asked to reply to the sender;
      * a JSON id dict sent by another client - validated and stored in
        ``self.clients`` keyed by the sender's IP address.
    """

    RECEIVE_TIMEOUT = 1  # seconds; recvfrom() wakes up this often to re-check stop_flag
    MAX_MESSAGE_LENGTH = 4096  # recvfrom() buffer size in bytes

    def __init__(self, parent):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.parent = parent
        self.stop_flag = False
        self.clients_lock = threading.Lock()  # guards self.clients
        self.clients = {}
        threading.Thread.__init__(self, daemon=True)

    def create_socket(self):
        """Open a UDP socket bound to BROADCAST_RECV_PORT.

        Returns:
            The bound socket, or None on failure. If the port is already in
            use, stop_flag is also set so run() exits instead of retrying.
        """
        self.logger.debug("Creating listen socket for port %d" % BROADCAST_RECV_PORT)
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.bind(('', BROADCAST_RECV_PORT))
            sock.settimeout(self.RECEIVE_TIMEOUT)
        except OSError as e:
            # socket.error is merely an alias of OSError on Python 3.3+, so one
            # handler covers every failure; "port busy" is told apart by errno.
            if sock is not None:
                sock.close()  # don't leak the half-initialised socket
            addr_in_use = {errno.EADDRINUSE, getattr(errno, "WSAEADDRINUSE", errno.EADDRINUSE)}
            if e.errno in addr_in_use:
                self.logger.warning("Clients detection receiver port is already in use - disabling clients detection")
                self.stop_flag = True
            else:
                self.logger.warning("Can't start clients detection receiver:", exc_info=True)
            return None
        return sock

    def validate_message(self, message_dict):
        """Return True if ``message_dict`` has the shape of an id dict.

        The value comes straight from the network, so nothing is assumed: it
        must be a dict containing the keys "login", "printers", "port",
        "remote_control" and "platform"; "printers" must be a list of dicts;
        and each printer's "profile" must be a dict. Only the argument is
        read, so no locking is needed.
        """
        if not isinstance(message_dict, dict):
            return False
        required_keys = ("login", "printers", "port", "remote_control", "platform")
        if not all(key in message_dict for key in required_keys):
            return False
        printers = message_dict["printers"]
        if not isinstance(printers, list):
            return False
        return all(isinstance(printer, dict) and isinstance(printer.get('profile'), dict)
                   for printer in printers)

    def is_loopback(self, IP):
        """Report whether IP belongs to this host. Currently always False (stub)."""
        # TODO find a more reliable way to check for loopback; a sketch comparing
        # socket.gethostbyaddr(IP) names with socket.gethostname() was left disabled.
        return False

    def _handle_message(self, message, address):
        """Answer a detection probe, or record a valid id dict sent by another client."""
        if message == BROADCAST_MESSAGE:
            if config.get_settings()["remote_control"]["detectable"]:
                self.parent.send_id_dict(address[0])
            return
        try:
            id_dict = json.loads(message)
        except (ValueError, TypeError):
            self.logger.exception("Client detector received a message(%s) that is not a valid json from %s.\nException: " % (str(message), str(address)))
            return
        if not self.validate_message(id_dict):
            self.logger.warning("Client detector received invalid id_dict from %s: %s" % (str(address), str(message)))
            return
        if not self.is_loopback(address[0]):
            with self.clients_lock:
                self.clients[address[0]] = id_dict

    @log.log_exception
    def run(self):
        """Receive loop: (re)open the socket as needed and dispatch datagrams until stop_flag is set."""
        sock = None
        while not self.stop_flag:
            if sock is None:
                sock = self.create_socket()
                if sock is None:
                    if not self.stop_flag:
                        time.sleep(6)  # back off before retrying socket creation
                    continue
            try:
                message, address = sock.recvfrom(self.MAX_MESSAGE_LENGTH)
            except socket.timeout:
                continue  # periodic wake-up so stop_flag is re-checked
            except socket.error as e:
                # Drop the broken socket; the next iteration opens a fresh one.
                sock.close()
                sock = None
                self.logger.warning("Client detector error while receiving: " + str(e))
                continue
            self._handle_message(message, address)
        if sock is not None:
            sock.close()

    def flush_clients_list(self):
        """Forget every client recorded so far."""
        with self.clients_lock:
            self.clients = {}


class ClientScanner:
    """Discover other clients on the LAN and answer their detection probes.

    When the ``remote_control.broadcasting`` setting is enabled, a
    BroadcastReceiver thread is started; otherwise ``self.receiver`` is None
    and detect() returns None.

    ``parent`` is the application object: its ``printer_interfaces`` and
    ``user_login`` attributes are read when building the id dict.
    """

    DETECTION_TIMEOUT = 5  # seconds detect() waits for replies after broadcasting

    def __init__(self, parent):
        self.logger = logging.getLogger(__name__)
        self.parent = parent
        if config.get_settings()["remote_control"]["broadcasting"]:
            self.receiver = BroadcastReceiver(self)
            self.receiver.start()
        else:
            self.receiver = None

    def create_id_dict_json(self):
        """Return this client's identification dict as UTF-8 encoded JSON bytes.

        The dict holds the login, platform, a summary of each printer in
        ``parent.printer_interfaces``, and the web-interface port and
        remote-control web_server setting from the config. A warning is
        logged (the payload is still returned) when it exceeds
        BroadcastReceiver.MAX_MESSAGE_LENGTH, the receiver's buffer size.
        """
        printers = [{'name': getattr(printer, 'printer_name', ''),
                     'id_dict': getattr(printer, 'id_dict', getattr(printer, 'usb_info', {})),
                     'state': printer.get_printer_state(),
                     'profile': getattr(printer, 'printer_profile', {})}
                    for printer in self.parent.printer_interfaces]
        settings = config.get_settings()
        id_dict = {"login": getattr(self.parent.user_login, "login", "No login"),
                   "platform": platforms.get_platform(),
                   "printers": printers,
                   "port": settings['web_interface']['port'],
                   "remote_control": settings['remote_control']['web_server']}
        payload = json.dumps(id_dict).encode("utf-8")
        # Measure the encoded bytes: the receiving recvfrom() buffer is sized in bytes.
        if len(payload) > BroadcastReceiver.MAX_MESSAGE_LENGTH:
            self.logger.warning("Client detection id dictionary message is too long. It can't be received properly!")
        return payload

    def send_id_dict(self, IP):
        """Send this client's id dict to IP on BROADCAST_RECV_PORT.

        Used by BroadcastReceiver to answer a detection probe. Socket errors
        are logged, not raised.
        """
        try:
            payload = self.create_id_dict_json()
            # The with-block guarantees close(). No shutdown(): it is meaningless
            # for UDP and fails with ENOTCONN on an unconnected socket.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.settimeout(1)
                sock.bind(('', 0))
                sock.sendto(payload, (IP, BROADCAST_RECV_PORT))
        except socket.error as e:
            self.logger.warning("Client detector error on send: " + str(e))

    def detect(self):
        """Broadcast a detection probe and collect replies.

        Blocks for DETECTION_TIMEOUT seconds after sending.

        Returns:
            A deep copy of the {ip: id_dict} table gathered by the receiver,
            or None when detection is disabled or the broadcast failed.
        """
        if not self.receiver:
            return None
        self.logger.info("Detecting other clients in LAN...")
        self.receiver.flush_clients_list()
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
                sock.settimeout(self.receiver.RECEIVE_TIMEOUT)
                sock.bind(('', 0))
                sock.sendto(BROADCAST_MESSAGE, ('255.255.255.255', BROADCAST_RECV_PORT))
        except socket.error as e:
            self.logger.warning("Client detector broadcasting error: " + str(e))
            return None
        time.sleep(self.DETECTION_TIMEOUT)
        # Copy under the lock so the caller gets a snapshot the receiver thread can't mutate.
        with self.receiver.clients_lock:
            clients_copy = copy.deepcopy(self.receiver.clients)
        self.logger.info("Detected clients: " + str(clients_copy))
        return clients_copy

    def forget_client(self, IP):
        """Drop IP from the detected-clients table; no-op if unknown or detection is disabled."""
        if not self.receiver:
            return
        with self.receiver.clients_lock:
            self.receiver.clients.pop(IP, None)

    def close(self):
        """Stop the receiver thread (if any) and wait for it to finish its current wait and exit."""
        if self.receiver:
            self.receiver.stop_flag = True
            self.receiver.join()

if __name__ == "__main__":
    # Manual smoke test: run one LAN detection round with a stub application object.
    import pprint

    class FakeApp:
        """Bare stand-in for the application; the attributes ClientScanner reads are set below."""

    fake_app = FakeApp()
    fake_app.stop_flag = False
    fake_app.printer_interfaces = []
    fake_app.user_login = None
    log.create_logger("", None)
    scanner = ClientScanner(fake_app)
    detected_clients = scanner.detect()
    pprint.pprint(detected_clients)
