# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import time
import simplejson as json
import socket
import logging
import threading
import os.path

import config
import usb_detect
import paths


# Seconds a listening socket waits for a UDP reply (passed to settimeout).
LISTEN_TIMEOUT = 3


class Listening(threading.Thread):
    """Background thread that waits for a single UDP datagram on a port.

    After the thread finishes, ``data`` holds the received payload and
    ``addr`` the sender's address. Both stay ``None`` if nothing arrived
    before LISTEN_TIMEOUT expired or if the socket could not be created.
    The socket is closed once the thread is done with it.
    """

    def __init__(self, port):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.data = None
        self.addr = None
        self.socket = self.create_socket(port)
        threading.Thread.__init__(self)

    def create_socket(self, port):
        """Create a UDP socket bound to *port* on all interfaces.

        Returns the socket, or None if it could not be created or bound.
        A partially initialised socket is closed rather than leaked.
        """
        # %s rather than %d so a misconfigured (non-int) port is logged
        # instead of raising TypeError before the error handling below.
        self.logger.debug("Creating listen socket for port %s" % port)
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(('', port))
            sock.settimeout(LISTEN_TIMEOUT)
        except socket.error:
            self.logger.debug("Socket init error. Port:" + str(port), exc_info=True)
        except (ValueError, TypeError, OverflowError):
            # bind() raises OverflowError for ports outside 0-65535 and
            # TypeError for non-integer ports.
            self.logger.debug("Not valid port number:" + str(port), exc_info=True)
        else:
            self.logger.debug("...done")
            return sock
        if sock is not None:
            sock.close()
        return None

    def run(self):
        """Receive at most one datagram (up to 1024 bytes), then close the socket."""
        if not self.socket:
            return
        try:
            self.data, self.addr = self.socket.recvfrom(1024)
        except socket.error:
            # socket.timeout is a subclass of socket.error: no reply arrived
            # in time (or a network error occurred); data/addr stay None.
            pass
        finally:
            self.socket.close()


class NetPrinterScan:
    """One-shot UDP broadcast discovery for a single printer profile.

    The scan runs synchronously in ``__init__``. A Listening thread is started
    on ``listen_port``. The profile's message is then broadcast to
    ``target_port`` from ``broadcast_port``. A reply, if any, is parsed into
    ``discovered_printers``.
    """

    RETRIES = 1
    RETRY_TIMEOUT = 0.1

    def __init__(self, profile):
        self.logger = logging.getLogger('app.' + __name__)
        self.profile = profile
        network_profile = profile['network_detect']
        self.listen_port = network_profile.get('listen_port', None)
        self.broadcast_port = network_profile.get('broadcast_port', None)
        self.target_port = network_profile.get('target_port', None)
        message = network_profile.get('message', None)
        if network_profile.get('json_message', None):
            message = json.dumps(message)
        self.message = message
        self.already_detected = []
        self.discovered_printers = []
        if self.broadcast_port and self.listen_port and self.target_port:
            self.listen = Listening(self.listen_port)
            self.listen.start()
            self.broadcast()
            # Also covers the case where the broadcast socket could not be
            # created and broadcast() returned early.
            self.listen.join()
        else:
            error = "Error in config file. Section network_detect of profile: " + str(self.profile)
            self.logger.critical(error)

    def create_broadcasting_socket(self):
        """Create a UDP broadcast socket bound to ``broadcast_port``.

        Returns the socket, or None on failure. A partially initialised
        socket is closed rather than leaked.
        """
        self.logger.debug("Creating broadcast socket on port %s" % self.broadcast_port)
        sock = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.settimeout(LISTEN_TIMEOUT-1)
            sock.bind(('', self.broadcast_port))
        except socket.error:
            self.logger.debug("Socket init error. Port: %s" % self.broadcast_port, exc_info=True)
        except (ValueError, TypeError, OverflowError):
            # bind() raises OverflowError/TypeError for out-of-range or
            # non-integer ports.
            self.logger.debug("Not valid port number: %s" % self.broadcast_port, exc_info=True)
        else:
            return sock
        if sock is not None:
            sock.close()
        return None

    def _encode_message(self):
        """Return the discovery message as bytes, which socket.sendto requires on Python 3."""
        message = self.message
        if isinstance(message, (bytes, bytearray)):
            return bytes(message)
        return str(message).encode('utf-8')

    def broadcast(self):
        """Broadcast the discovery message, then process the listener's reply.

        The message is sent up to RETRIES times. Between sends the listener
        gets RETRY_TIMEOUT seconds to finish. Once it has finished (reply or
        timeout), resending is pointless because nothing is listening.
        The broadcast socket is always closed.
        """
        self.logger.debug("Sending broadcast to port %s" % self.target_port)
        bc_socket = self.create_broadcasting_socket()
        if not bc_socket:
            return
        payload = self._encode_message()
        try:
            for attempt in range(1, self.RETRIES + 1):
                try:
                    bc_socket.sendto(payload, ('255.255.255.255', self.target_port))
                except socket.error:
                    self.logger.debug("Error sending broadcast", exc_info=True)
                if attempt < self.RETRIES:
                    self.listen.join(self.RETRY_TIMEOUT)
                    if not self.listen.is_alive():
                        break
            self.listen.join()
            if self.listen.addr:
                self.process_response(self.listen.data, self.listen.addr)
            self.logger.debug("Done listening to port " + str(self.listen_port))
        finally:
            bc_socket.close()

    def process_response(self, response, addr):
        """Record a reply from *addr* in ``discovered_printers``, once per address.

        When the profile's ``network_detect`` section sets ``json_response``,
        *response* is parsed as JSON. IP/SNR/VID/PID are then read from the
        keys named by ``IP_field``/``SNR_field``/``VID_field``/``PID_field``.
        If parsing or field lookup fails, the printer is still recorded with
        just the sender's IP and port.
        """
        if addr in self.already_detected:
            return
        self.already_detected.append(addr)
        printer = {'IP': addr[0], 'port': addr[1]}
        network_profile = self.profile['network_detect']
        if network_profile.get('json_response'):
            try:
                response = json.loads(response)
            except (ValueError, TypeError):
                self.logger.debug('Response from printer should be valid json. Its malformed or not json')
            else:
                try:
                    ip = response[network_profile['IP_field']]
                    snr = response[network_profile['SNR_field']]
                    vid = response[network_profile['VID_field']]
                    pid = response[network_profile['PID_field']]
                except (KeyError, TypeError, IndexError):
                    self.logger.warning("Printer response lacks an expected field: %s" % str(response),
                                        exc_info=True)
                else:
                    printer['IP'] = ip
                    if addr[0] != ip:
                        self.logger.warning("Detected printer IP didn't match with IP field in response")
                    printer['SNR'] = snr
                    printer['VID'] = usb_detect.USBDetector.format_vid_or_pid(vid)
                    printer['PID'] = usb_detect.USBDetector.format_vid_or_pid(pid)
        self.discovered_printers.append(printer)

    def close(self):
        """Placeholder for releasing scan resources; currently does nothing."""
        pass  # TODO implement close here

class NetworkDetector:
    """Tracks network printers found by broadcast scans or added manually by IP.

    Both lists are persisted as JSON in STORAGE_FILE_PATH and reloaded on
    construction.
    """
    STORAGE_FILE_PATH = os.path.join(paths.CURRENT_SETTINGS_FOLDER, 'detected_printers.json')

    def __init__(self, app):
        self.app = app
        self.logger = logging.getLogger(self.__class__.__name__)
        self.detected_printers = []
        self.printers_add_by_ip = []
        self.load_printers_list()
        # A scan only runs on the next get_printers_list() call after this is set.
        self.network_detector_run_flag = False

    def get_all_network_profiles(self):
        """Return every configured profile that has a ``network_detect`` section."""
        return [x for x in config.get_profiles() if x.get('network_detect')]

    def get_unique_network_profiles(self):
        """Return the first profile for each distinct ``network_detect`` section.

        A list is used for de-duplication because the sections are dicts,
        which are unhashable.
        """
        profiles = []
        network_detects = []
        for profile in self.get_all_network_profiles():
            if profile['network_detect'] not in network_detects:
                profiles.append(profile)
                network_detects.append(profile['network_detect'])
        return profiles

    def get_printers_list(self):
        """Return detected plus manually added printers, rescanning if requested.

        Returns [] if ``app.stop_flag`` becomes set during a scan.
        """
        if self.network_detector_run_flag:
            self.logger.info('Scanning for network printers...')
            self.detected_printers = []
            for profile in self.get_unique_network_profiles():
                if self.app.stop_flag:
                    return []
                scanner = NetPrinterScan(profile)
                self.detected_printers.extend(scanner.discovered_printers)
            self.logger.info('Discovered printers:\n%s', self.detected_printers)
            self.save_printers_list()
            self.network_detector_run_flag = False
        return self.detected_printers + self.printers_add_by_ip

    def printers_list_with_profiles(self):
        """Pair each known printer with the profile of a matching printer interface, or {}."""
        printers_list = []
        for info in self.get_printers_list():
            for pi in self.app.printer_interfaces:
                if pi.usb_info == info:
                    printers_list.append({"usb_info": info, "profile": pi.printer_profile})
                    break
            else:
                printers_list.append({"usb_info": info, "profile": {}})
        return printers_list

    def connect_to_network_printer(self, printer_ip, printer_type):
        """Add a printer by IP using the first VID/PID of the profile named *printer_type*.

        The IP doubles as the serial number. Does nothing if no profile has that name.
        """
        # TODO find a way to pass the selected printer type to its printer interface
        for profile in self.get_all_network_profiles():
            if printer_type == profile.get("name"):
                printer = {
                    'IP': printer_ip,
                    'SNR': printer_ip,
                    'VID': profile["vids_pids"][0][0],
                    'PID': profile["vids_pids"][0][1]
                }
                if printer not in self.printers_add_by_ip:
                    self.printers_add_by_ip.append(printer)
                    self.save_printers_list()
                break

    def forget(self, usb_info):
        """Remove every stored printer equal to *usb_info* and persist the result."""
        for group in (self.detected_printers, self.printers_add_by_ip):
            # Rebuild in place: calling remove() while iterating the same list
            # skips the element after each removal.
            group[:] = [printer for printer in group if printer != usb_info]
        self.save_printers_list()

    def save_printers_list(self):
        """Write both printer lists to STORAGE_FILE_PATH as indented JSON; errors are logged."""
        file_path = self.STORAGE_FILE_PATH
        data = {'detected': self.detected_printers, 'added_by_ip': self.printers_add_by_ip}
        try:
            # Serialize before opening the file so a serialization error
            # does not truncate the previously stored list.
            json_config = json.dumps(data, indent=4, separators=(',', ': '))
            with open(file_path, 'w') as f:
                f.write(json_config)
        except Exception as e:
            self.logger.error("Error writing data to %s: %s" % (file_path, str(e)))

    def load_printers_list(self):
        """Load both printer lists from STORAGE_FILE_PATH, if the file exists; other errors are logged."""
        file_path = self.STORAGE_FILE_PATH
        try:
            with open(file_path, 'r') as f:
                data = f.read()
                self.logger.info("Stored printers list: %s" % data)
                printers_list = json.loads(data)
                # "or []" guards against null entries, which would later break
                # the list concatenation in get_printers_list().
                self.detected_printers = printers_list.get('detected') or []
                self.printers_add_by_ip = printers_list.get('added_by_ip') or []
        except FileNotFoundError:
            pass
        except Exception as e:
            self.logger.error("Error reading data from %s: %s" % (file_path, str(e)))


if __name__ == '__main__':
    # Manual smoke test: log in, load the user's printer profiles, run one
    # network scan and print the printers that were found.
    import user_login
    import log

    class FakeApp:
        pass

    fake_app = FakeApp()
    fake_app.stop_flag = False
    log.create_logger("", None)
    # Named "login" so the user_login module is not shadowed by its instance.
    login = user_login.UserLogin(fake_app)
    login.wait()
    if login.user_token and hasattr(login, "profiles"):
        config.Config.instance().set_profiles(login.profiles)
    detector = NetworkDetector(fake_app)
    detector.network_detector_run_flag = True
    found_printers = detector.get_printers_list()
    print("Detected network printers: ")
    print(json.dumps(found_printers))
