# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import time
import json
import select
import socket
import logging
import threading
import os.path
import pprint

import config
import usb_detect
import paths
import http_rrfw_detect
import http_reperier_detect
import zeroconf_detect
import zeroconf_formlabs_detect
import upnp_detect
import broadcast_printer_detector
import base_detector
import printer_settings_and_id


# TODO refactor
# 1) use multithreading or asyncio to scan simultaneously
# 2) decouple starting, running and processing from network detector and other detectors
# 3) refactor to simplify network detector logic and minimize this class

# Registry of network scanner classes, keyed by the 'scanner' value found in a
# profile's 'network_detect' section. 'noscanner' and None map to None, meaning
# no active scan is performed for that profile.
SCANNERS = {
    'broadcast': broadcast_printer_detector.BroadcastPrinterDetector,
    'birdwing': broadcast_printer_detector.BroadcastPrinterDetector,
    'rrfw': http_rrfw_detect.HTTPRRFWDetector,
    'repetier': http_reperier_detect.HTTPRepetierDetector,
    'zeroconf': zeroconf_detect.ZeroconfDetector,
    'zeroconf_formlabs': zeroconf_formlabs_detect.ZeroconfDetector,
    'upnp': upnp_detect.UPNPDetector,
    'noscanner': None,
    None: None,
}

class NetworkDetector(base_detector.BaseDetector):
    """Detector for network (LAN) printers.

    Maintains two lists of printer id dicts, both persisted as JSON in
    STORAGE_FILE_PATH:
      * detected_printers  - printers found by the scanners configured in profiles
      * printers_add_by_ip - printers added explicitly via remember_printer()

    `self.logger` is not set here; presumably it is provided by BaseDetector.
    """

    STORAGE_FILE_PATH = os.path.join(paths.CURRENT_SETTINGS_FOLDER, 'detected_printers.json')
    DEFAULT_SCANNER_CLASS = broadcast_printer_detector.BroadcastPrinterDetector

    # Keys compared by is_same_printer(); an id dict may carry any subset of them.
    PRINTER_ID_DICT_KEYS = ('VID', 'PID', 'SNR', 'IP')
    UNKNOWN_PROFILE_VID_PID = 'XXXX', 'XXXX'

    @staticmethod
    def is_same_printer(checking_dict, checked_dict): #TODO rename args in all other classes
        """Return True if every id key set in checking_dict has the same value in checked_dict.

        Keys that are missing or falsy in checking_dict act as wildcards, so a
        partial id (e.g. only 'IP') matches any printer with that value.
        """
        for key in NetworkDetector.PRINTER_ID_DICT_KEYS:
            checking_value = checking_dict.get(key)
            if checking_value and checking_value != checked_dict.get(key):
                return False
        return True

    @staticmethod
    def get_all_network_conns(profile):
        """Return a lazy iterator over the profile's v2 connections whose type is 'LAN'.

        A missing or null 'v2' section or 'connections' list yields nothing.
        """
        connections = (profile.get('v2') or {}).get('connections') or []
        return filter(lambda conn: conn.get('type') == 'LAN', connections)

    def __init__(self, parent=None):
        # NOTE(review): `self` is passed explicitly to BaseDetector.__init__ as its first
        # argument; confirm this matches BaseDetector's signature.
        super().__init__(self)
        self.parent = parent
        self.lists_lock = threading.RLock()  # guards both printer lists (reentrant: save is nested)
        self.run_lock = threading.Lock()     # serializes scanning in get_printers_list()
        self.detected_printers = []
        self.printers_add_by_ip = []
        self.load_printers_list()
        self.network_detector_run_flag = False

    def refresh_on_next_loop(self):
        """Request that the next get_printers_list() call runs a network scan."""
        self.network_detector_run_flag = True

    def parent_stop(self):
        """Return the parent's stop_flag, or a falsy value when there is no parent."""
        return self.parent and self.parent.stop_flag

    def get_all_network_profiles(self):
        """Return profiles with a 'network_detect' section or at least one LAN connection."""
        return [p for p in config.get_profiles() if p.get('network_detect') or any(self.get_all_network_conns(p))]

    def get_unique_network_profiles(self):
        """Return network profiles, keeping only the first one per distinct 'network_detect' value.

        Profiles without a truthy 'network_detect' section are skipped, so each
        distinct scanner configuration is run only once.
        """
        profiles = []
        network_detects = []
        for profile in self.get_all_network_profiles():
            detector = profile.get('network_detect')
            if detector and detector not in network_detects:
                profiles.append(profile)
                network_detects.append(detector)
        return profiles

    def detect(self):
        """Run each configured network scanner once and collect what they discover.

        Returns a list with the discovered printer id dicts of ALL scanners, or []
        if the parent requests a stop before the scan finishes.
        """
        self.logger.info('Scanning for network printers...')
        detected_printers = []
        for profile in self.get_unique_network_profiles():
            if self.parent_stop():
                return []
            scanner_name = (profile.get('network_detect') or {}).get('scanner')
            if not scanner_name or scanner_name == 'noscanner':
                continue
            scanner_class = SCANNERS.get(scanner_name)
            if not scanner_class:
                self.logger.error(f'Unknown scanner name: {scanner_name}')
                continue
            scanner = scanner_class(self.parent, profile)
            scanner.detect()
            self.logger.debug('Scanner:\t%s\t%s' % (scanner, scanner.discovered_printers))
            # Accumulate rather than overwrite: each profile may use a different scanner,
            # and earlier scanners' results must not be lost. Duplicates are filtered
            # later, in get_printers_list().
            detected_printers.extend(scanner.discovered_printers)
        self.logger.info('Discovered printers:\n%s' % detected_printers)
        return detected_printers

    def get_printers_list(self):
        """Return detected printers followed by printers added by IP, as a new list.

        If refresh_on_next_loop() was called, detect() runs first. Newly found
        printers that are not already in either list are appended to
        detected_printers, and the lists are saved.
        """
        with self.run_lock:
            if self.network_detector_run_flag:
                detected = self.detect()
                with self.lists_lock:
                    for printer in detected:
                        if printer not in self.detected_printers and printer not in self.printers_add_by_ip:
                            self.detected_printers.append(printer)
                    self.save_printers_list()
                self.network_detector_run_flag = False
        with self.lists_lock:
            return self.detected_printers + self.printers_add_by_ip

    def printers_list_with_profiles(self):
        """Pair each known printer id dict with the profile of its running printer interface.

        Returns a list of {"usb_info": ..., "profile": ...} dicts. "profile" is {}
        when no interface in parent.printer_interfaces has a matching usb_info.
        """
        printers_list = []
        for info in self.get_printers_list():
            for pi in getattr(self.parent, "printer_interfaces", []):
                if pi.usb_info == info:
                    printers_list.append({"usb_info": info, "profile": pi.printer_profile})
                    break
            else:
                printers_list.append({"usb_info": info, "profile": {}})
        return printers_list

    def _get_lan_conn_ids(self, profile, conn_id=None):
        """Return (VID, PID, connection id) taken from one of the profile's LAN connections.

        With conn_id, only the LAN connection with that id is considered; otherwise
        the first LAN connection is used. Returns None if there is no such connection,
        or if its 'ids' entry is missing or malformed. The returned connection id is
        None when the connection has no 'id'.
        """
        conns = list(self.get_all_network_conns(profile))
        if conn_id:
            conns = [conn for conn in conns if conn.get('id') == conn_id]
        if not conns:
            return None
        conn = conns[0]
        try:
            ids = conn['ids'][0]
            return ids['VID'], ids['PID'], conn.get('id')
        except (IndexError, AttributeError, KeyError, TypeError):
            return None

    def remember_printer(self, printer_type_name_or_alias, ip, port=None, vid=None, pid=None, serial_number=None, password=None, run_detector=False, conn_id=None):
        """Add a network printer for the profile whose "name" or "alias" matches.

        If run_detector is set and the profile configures a scanner, that scanner probes
        `ip`/`port`, and everything it discovers that is not already known is added to
        printers_add_by_ip. Otherwise a printer id dict {IP, SNR, VID, PID[, PORT, PASS]}
        is built and added. SNR defaults to `ip`. VID/PID fall back first to the profile's
        LAN connection ids (v2 profiles without 'network_detect'), then to the first
        pair in profile["vids_pids"].

        Returns True once a matching profile was handled, False otherwise.
        """
        for profile in self.get_all_network_profiles():
            if printer_type_name_or_alias not in (profile.get("name"), profile.get("alias")):
                continue
            printer_type_alias = profile.get("alias")
            scanner_name = (profile.get('network_detect') or {}).get('scanner')
            if run_detector and scanner_name: #disable scan for birdwing printers
                if scanner_name not in SCANNERS:
                    self.logger.warning(f'Unknown scanner: {scanner_name}')
                    continue
                scanner_class = SCANNERS[scanner_name]
                if scanner_class:
                    scanner = scanner_class(self.parent, profile)
                    scanner.detect(already_know_ip=ip, non_default_port=port)
                    with self.lists_lock:
                        for printer in scanner.discovered_printers:
                            if printer not in self.printers_add_by_ip and printer not in self.detected_printers:
                                self.printers_add_by_ip.append(printer)
                        self.save_printers_list()
                    return True
                # NOTE(review): a scanner name mapped to None (e.g. 'noscanner') falls
                # through without remembering anything; confirm this is intended.
                continue
            if not serial_number:
                serial_number = ip
            if (not vid or not pid) and "v2" in profile and not profile.get('network_detect'):
                found = self._get_lan_conn_ids(profile, conn_id)
                if found:
                    vid, pid, conn_id = found
                if not vid or not pid:
                    self.logger.warning(f'Unable to get VID and PID for {printer_type_name_or_alias}. Using v1 profile ids')
            if not vid or not pid:
                vid, pid = profile["vids_pids"][0]
            printer = {
                'IP': ip,
                'SNR': serial_number,
                'VID': vid,
                'PID': pid
            }
            printer_settings = {'type_alias': printer_type_alias}
            if conn_id:
                printer_settings['connection_id'] = conn_id
            # The settings id is computed before PORT/PASS are added to the dict.
            printer_settings_and_id.save_settings(printer_settings_and_id.create_id_string(printer), printer_settings)
            if port is not None:
                printer['PORT'] = port
            if password is not None:
                printer['PASS'] = password
            with self.lists_lock:
                if printer not in self.printers_add_by_ip:
                    self.printers_add_by_ip.append(printer)
                    self.save_printers_list()
            self.logger.info('Network printers list updated to: ' + pprint.pformat(self.printers_add_by_ip))
            return True
        return False

    def forget_printer(self, usb_info):
        """Remove the first stored printer matching usb_info (see is_same_printer).

        Both lists are searched, detected printers first. Returns True and saves the
        lists if a printer was removed, otherwise logs a warning and returns False.
        """
        with self.lists_lock:
            for group in (self.detected_printers, self.printers_add_by_ip):
                for printer in group:
                    if self.is_same_printer(usb_info, printer):
                        # Safe to remove while iterating: we return immediately afterwards.
                        group.remove(printer)
                        self.logger.info(f'Forgetting a printer with id: {usb_info}')
                        self.save_printers_list()
                        return True
            self.logger.warning(f'Unable to forget a printer with id: {usb_info}')
            return False

    def save_printers_list(self):
        """Write both printer lists to STORAGE_FILE_PATH as JSON; errors are logged, not raised.

        The data is serialized before the file is opened, so a serialization error
        no longer truncates the existing file.
        """
        file_path = self.STORAGE_FILE_PATH
        try:
            with self.lists_lock:
                data = {'detected': self.detected_printers, 'added_by_ip': self.printers_add_by_ip}
                json_config = json.dumps(data, indent=4, separators=(',', ': '))
            with open(file_path, 'w') as f:
                f.write(json_config)
        except Exception as e:
            self.logger.error("Error writing data to %s: %s" % (file_path, str(e)))

    def load_printers_list(self):
        """Load both printer lists from STORAGE_FILE_PATH.

        A missing file is ignored silently. Other errors are logged, and the lists
        keep their current values.
        """
        with self.lists_lock:
            file_path = self.STORAGE_FILE_PATH
            try:
                with open(file_path, 'r') as f:
                    data = f.read()
                    self.logger.info("Stored printers list: %s" % data)
                    printers_list = json.loads(data)
                    self.detected_printers = printers_list.get('detected', [])
                    self.printers_add_by_ip = printers_list.get('added_by_ip', [])
            except FileNotFoundError:
                pass
            except Exception as e:
                self.logger.error("Error reading data from %s: %s" % (file_path, str(e)))


if __name__ == '__main__':
    # Manual smoke test: run one network scan and print the combined printer list.
    # Call pprint.pprint instead of `from pprint import pprint`. That import would rebind
    # the module-level name `pprint` to a function and break the pprint.pformat()
    # call used by NetworkDetector.remember_printer.
    logging.basicConfig(level=logging.DEBUG)
    nd = NetworkDetector()
    nd.refresh_on_next_loop()
    print("Detected network printers:")
    pprint.pprint(nd.get_printers_list())
