# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import copy
import json
import logging
import threading
import operator
import os.path
import pprint

import bambulab_scanner
import config
import paths
import base_detector
import printer_settings_and_id
import http_rrfw_scanner
import http_reperier_scanner
import zeroconf_scanner
import upnp_scanner
import broadcast_printer_scanner
import http_prusalink_scanner


class NetworkDetector(base_detector.BaseDetector):
    """Detect, remember and persist network (LAN) printers.

    Two printer lists are kept and persisted as JSON at STORAGE_FILE_PATH:
    ``detected_printers`` (the 'detected' section of the file) and
    ``printers_add_by_ip`` (the 'added_by_ip' section, filled by
    remember_printer). Results of the latest scan are kept separately in
    ``just_detected_printers``.
    Identity helpers (``logger``, ``PRINTER_ID_DICT_KEYS``, ``ALL_KEYS``) are
    presumably provided by BaseDetector - confirm there.
    """

    # Maps the scanner name used in profiles to the detector class that runs it.
    # A value of None means "no scanner"; 'noscanner' is skipped silently.
    SCANNERS = {'broadcast': broadcast_printer_scanner.BroadcastPrinterDetector,
                'birdwing': broadcast_printer_scanner.BroadcastPrinterDetector,
                'rrfw' : http_rrfw_scanner.HTTPRRFWDetector,
                'repetier' : http_reperier_scanner.HTTPRepetierDetector,
                'zeroconf' : zeroconf_scanner.ZeroconfDetector,
                'zeroconf_formlabs' : zeroconf_scanner.ZeroconfDetector,
                'prusalink' : http_prusalink_scanner.HTTPPrusaLinkDetector,
                'bambulab_scanner': bambulab_scanner.BambuLabDetector,
                #'upnp' : upnp_scanner.UPNPDetector,
                'noscanner': None,
                None: None}

    STORAGE_FILE_PATH = os.path.join(paths.CURRENT_SETTINGS_FOLDER, 'detected_printers.json')

    UNKNOWN_PROFILE_VID_PID = 'XXXX', 'XXXX'

    @staticmethod
    def is_same_printer(checking_dict, checked_dict): #TODO rename args in all other classes
        """Return True if checked_dict matches every id value set in checking_dict.

        Only keys from PRINTER_ID_DICT_KEYS are compared. A key that is missing or
        empty (falsy) in checking_dict acts as a wildcard.
        """
        for key in NetworkDetector.PRINTER_ID_DICT_KEYS:
            checking_value = checking_dict.get(key)
            checked_value = checked_dict.get(key)
            if checking_value and checking_value != checked_value:
                return False
        return True

    @staticmethod
    def get_all_network_conns(profile):
        """Return a lazy iterator over the profile's v2 connections of type 'LAN'."""
        return filter(lambda conn: conn.get('type') == 'LAN', profile.get('v2', {}).get('connections', []))

    @staticmethod
    def get_all_network_profiles():
        """Return profiles with legacy 'network_detect' or any v2 LAN connection, sorted by name."""
        profiles = sorted([p for p in config.get_profiles() if p.get('network_detect') or any(NetworkDetector.get_all_network_conns(p))], key=operator.itemgetter('name'))
        return profiles

    def __init__(self, parent=None):
        super().__init__(parent)
        self.parent = parent
        # Reentrant: save_printers_list re-acquires it while callers already hold it.
        self.lists_lock = threading.RLock()
        self.run_lock = threading.Lock()
        self.detected_printers = []
        self.printers_add_by_ip = []
        self.load_printers_list()
        self.just_detected_printers = []
        self.detection_thread = None
        self.detect_time_left = 0
        self.scanners_table = self.fill_scanners_table()

    def parent_stop(self):
        """Return a truthy value when the parent asks detection to stop."""
        return self.parent and self.parent.stop_flag

    def fill_scanners_table(self):
        """Build the scanner lookup table from all network profiles.

        Returns:
            dict of the form ``{scanner_name: [(ids, profile_dict, conn_id), ...]}``,
            where ``ids`` is a list of ``{'VID': ..., 'PID': ...}`` dicts and
            ``conn_id`` is "" for legacy 'network_detect' profiles.
        """
        scanners = {}
        for profile in self.get_all_network_profiles():
            scanner_name = profile.get('network_detect', {}).get('scanner')
            if scanner_name:
                vids_pids = [{'VID': vp[0], 'PID': vp[1]} for vp in profile['vids_pids']]
                scanners.setdefault(scanner_name, []).append((vids_pids, profile, ""))
            if 'v2' in profile:
                for conn in profile.get('v2', {}).get('connections', []):
                    scanner_name = conn.get('scanner')
                    if scanner_name:
                        scanners.setdefault(scanner_name, []).append((conn['ids'], profile, conn['id']))
        return scanners

    def get_scanners_table(self):
        """Return the table built by fill_scanners_table at construction time."""
        return self.scanners_table

    def detect(self, only_scanners=None, only_profile=None, only_conn=None):
        """Run every configured scanner and return the printers it discovered.

        Args:
            only_scanners: optional collection of scanner names to restrict the run to.
            only_profile, only_conn: passed through to each scanner's constructor.

        Each returned printer dict is tagged with 'SCNR' (scanner name). It is
        merged with any remembered printer that has equal VID/PID/SNR, while its
        freshly scanned 'IP' is kept. A deep copy of the result is also stored in
        just_detected_printers.
        """
        self.logger.info('Scanning for network printers...')
        detected_printers = []
        # Scanner classes exposing PASS_CACHE_ATTR_NAME share that attribute
        # between instances within one detect() run.
        caches = {}
        for scanner_name in self.get_scanners_table():
            if self.parent_stop():
                break
            if only_scanners and scanner_name not in only_scanners:
                continue
            scanner_class = self.SCANNERS.get(scanner_name)
            if not scanner_class:
                if scanner_name != 'noscanner':
                    self.logger.error(f'Profile got unknown scanner: {scanner_name}')
                continue
            scanner = scanner_class(self.parent, only_profile, only_conn)
            cache_attr_name = getattr(scanner_class, 'PASS_CACHE_ATTR_NAME', None)
            if cache_attr_name:
                cache = caches.get(cache_attr_name)
                if cache:
                    setattr(scanner, cache_attr_name, cache)
            scanner.detect()
            # Capture the cache even if this scanner found no printers.
            if cache_attr_name and not caches.get(cache_attr_name):
                new_cache = getattr(scanner, cache_attr_name, None)
                if new_cache:
                    caches[cache_attr_name] = new_cache
            for printer in scanner.get_discovered_printers():
                printer['SCNR'] = scanner_name
                # Fill in already-known info for this printer, but keep the freshly scanned IP.
                for known_printer in self.get_printers_list():
                    if all(printer.get(key) == known_printer.get(key) for key in ("VID", "PID", "SNR")):
                        ip = printer.get('IP')
                        printer.update(known_printer)
                        printer['IP'] = ip
                self.logger.debug(f'Scanner {scanner_name} found:' + pprint.pformat(printer))
                detected_printers.append(printer)
        with self.lists_lock:
            self.just_detected_printers = copy.deepcopy(detected_printers)
        self.logger.info('Detected network printers:\n%s', detected_printers)
        return detected_printers

    def get_printers_list(self, autodetect=False):
        """Return remembered printers: detected_printers followed by printers_add_by_ip.

        ``autodetect`` is accepted for interface compatibility but is currently ignored.
        """
        return self.detected_printers + self.printers_add_by_ip

    def printers_list_with_profiles(self):
        """Pair each remembered printer with the profile of the matching parent interface.

        Returns a list of ``{"usb_info": ..., "profile": ...}`` dicts. The profile
        is ``{}`` when no entry in ``parent.printer_interfaces`` matches.
        """
        printers_list = []
        for info in self.get_printers_list():
            for pi in getattr(self.parent, "printer_interfaces", []):
                if self.is_same_printer(info, pi.usb_info):
                    printers_list.append({"usb_info": info, "profile": pi.printer_profile})
                    break
            else:
                printers_list.append({"usb_info": info, "profile": {}})
        return printers_list

    def get_profile_by_type_name_or_alias(self, printer_type_name_or_alias):
        """Return the first network profile whose 'name' or 'alias' equals the argument, else None."""
        for profile in self.get_all_network_profiles():
            if printer_type_name_or_alias == profile.get("name") or printer_type_name_or_alias == profile.get("alias"):
                return profile
        return None

    def run_scanner_for_profile(self, profile, conn_id, vid=None, pid=None, ip=None, port=None):
        """Run the scanner that serves ``profile`` and return the matching discovered printers.

        Scanner selection: when no conn_id is given and no vid is given, or the
        [vid, pid] pair belongs to the legacy 'vids_pids', the legacy
        'network_detect' scanner is used. Otherwise the first v2 LAN connection
        (restricted to conn_id if given) whose 'ids' contain vid/pid is used. When
        neither vid nor pid is given, the first connection with a scanner is used
        and its first id becomes the VID/PID filter.

        Discovered printers are filtered by ``ip`` (if given) and, for scanners
        that report VID/PID, by the VID/PID filter.
        """
        detected = []
        scanner_name = None
        if not conn_id:
            legacy_scanner_name = (profile.get('network_detect') or {}).get('scanner')
            legacy_vids_pids = profile.get('vids_pids') or []
            if legacy_scanner_name and (not vid or [vid, pid] in legacy_vids_pids):
                scanner_name = legacy_scanner_name
        if not scanner_name:
            for conn in profile.get('v2', {}).get('connections', []):
                if conn.get('type') != 'LAN':
                    continue
                if conn_id and conn_id != conn.get('id'):
                    continue
                conn_scanner_name = conn.get('scanner')
                if not conn_scanner_name:
                    continue
                conn_ids = conn.get('ids') or []
                if not vid and not pid:
                    if not conn_ids:
                        continue
                    vid, pid = conn_ids[0]['VID'], conn_ids[0]['PID']
                    scanner_name = conn_scanner_name
                    break
                if {'VID': vid, 'PID': pid} in conn_ids:
                    scanner_name = conn_scanner_name
                    break
        if not scanner_name:
            return detected
        scanner_class = self.SCANNERS.get(scanner_name)
        if not scanner_class:
            if scanner_name != 'noscanner':
                self.logger.error(f'Profile got unknown scanner: {scanner_name}')
            return detected
        scanner = scanner_class(self.parent, profile)
        scanner.detect(ip, port)
        for discovered in scanner.get_discovered_printers():
            if ip and ip != discovered.get('IP'):
                continue
            # Filter against the requested vid/pid without mutating them, so every
            # discovered printer is checked against the same criteria.
            if scanner.CAN_DETECT_VID_PID:
                if vid and vid != discovered.get('VID'):
                    continue
                if pid and pid != discovered.get('PID'):
                    continue
            detected.append(discovered)
        return detected

    def remember_printer(self, printer_type_name_or_alias, ip=None, port=None, vid=None, pid=None, serial_number=None, password=None, ssh_password=None, run_detector=False, conn_id=None):
        """Add a printer of the given profile type to printers_add_by_ip.

        If ``run_detector`` is set or no ``ip`` is given, the profile's scanner is
        run to find an IP and serial number. Missing VID/PID are taken from the
        profile's legacy 'vids_pids' or its first matching v2 LAN connection. The
        serial number falls back to the IP.

        Returns:
            False if no network profile matches the type name/alias, else True.
        """
        profile = self.get_profile_by_type_name_or_alias(printer_type_name_or_alias)
        if not profile:
            self.logger.error('Unable to found network profile: ' + str(printer_type_name_or_alias))
            return False
        printer_type_alias = profile.get("alias")
        if run_detector or not ip:
            for result in self.run_scanner_for_profile(profile, conn_id, vid, pid, ip, port):
                if not ip:
                    ip = result.get('IP')
                if result.get('IP') == ip and not serial_number and result.get('SNR'):
                    serial_number = result.get('SNR')
        if not serial_number:
            serial_number = ip
        if not vid or not pid:
            if profile.get('network_detect') and not conn_id:
                vid, pid = profile.get('vids_pids')[0]
            elif "v2" in profile:
                for conn in self.get_all_network_conns(profile):
                    if conn_id and conn.get('id') != conn_id:
                        continue
                    conn_ids = conn.get('ids')
                    if not conn_ids:
                        continue
                    conn_id = conn.get('id')
                    vid = conn_ids[0]['VID']
                    pid = conn_ids[0]['PID']
                    break
        if not vid or not pid:
            self.logger.warning(f'Unable to get VID and PID for {printer_type_name_or_alias}')
        printer = {
            'IP': ip,
            'SNR': serial_number,
            'VID': vid,
            'PID': pid
        }
        printer_settings = {'type_alias': printer_type_alias}
        if conn_id:
            printer_settings['connection_id'] = conn_id
        # The id string is built before PORT/PASS/SSHP are added to the dict.
        printer_settings_and_id.save_settings(printer_settings_and_id.create_id_string(printer), printer_settings)
        if port is not None:
            printer['PORT'] = port
        if password is not None:
            printer['PASS'] = password
        if ssh_password is not None:
            printer['SSHP'] = ssh_password
        self.add_printer_to_list(printer, self.printers_add_by_ip)
        self.logger.info('Network printers list updated to: ' + pprint.pformat(self.printers_add_by_ip))
        return True

    def add_printer_to_list(self, printer, list_to_add, allow_overwrite=True, extra_excusion_lists=None):
        """Insert ``printer`` into ``list_to_add``, de-duplicating across all printer lists.

        Entries identified as the same printer (is_same_printer) are handled like this:
        - identical entries mean the printer is already present;
        - with allow_overwrite, a differing entry in ``list_to_add`` is replaced in
          place, and differing entries in the other lists are removed;
        - without allow_overwrite, the existing entry is kept and ``printer`` is not added.
        The lists are saved afterwards.
        """
        with self.lists_lock:
            already_added = False
            all_lists = [self.detected_printers, self.printers_add_by_ip]
            if extra_excusion_lists:
                all_lists.extend(extra_excusion_lists)
            for printers_list in all_lists:
                # Collect per list and delete after enumeration, so no entry is skipped.
                indexes_to_remove = []
                for existing_index, existing_printer in enumerate(printers_list):
                    if not self.is_same_printer(printer, existing_printer):
                        continue
                    if printer == existing_printer:
                        already_added = True
                    elif allow_overwrite:
                        self.logger.info(f"Removing existing printer {existing_printer} "
                                         f"to a printer with same is {printer}")
                        # Identity, not equality: two distinct lists may hold equal contents.
                        if printers_list is list_to_add:
                            printers_list[existing_index] = printer
                            already_added = True
                        else:
                            indexes_to_remove.append(existing_index)
                    else:
                        # Overwrite not allowed: keep the existing entry and do not add a duplicate.
                        already_added = True
                for index in reversed(indexes_to_remove):
                    del printers_list[index]
            if not already_added:
                list_to_add.append(printer)
            self.save_printers_list()

    def forget_printer(self, usb_info):
        """Remove the first remembered printer matching ``usb_info`` and save.

        Returns True if a printer was removed, False otherwise.
        """
        with self.lists_lock:
            for group in (self.detected_printers, self.printers_add_by_ip):
                for printer in group:
                    if self.is_same_printer(usb_info, printer):
                        group.remove(printer)
                        self.logger.info(f'Forgetting a printer with id: {usb_info}')
                        self.save_printers_list()
                        return True
            self.logger.warning(f'Unable to forget a printer with id: {usb_info}')
            return False

    def save_printers_list(self):
        """Persist both printer lists to STORAGE_FILE_PATH as JSON (best effort).

        The data is serialized before any file is touched. It is written to a
        temporary file that then replaces the target via os.replace, so a
        serialization error or an interrupted write does not leave a truncated
        file. The whole operation runs under lists_lock. Errors are logged, not
        raised.
        """
        # TODO: writers in other processes are not coordinated.
        file_path = self.STORAGE_FILE_PATH
        tmp_path = file_path + '.tmp'
        with self.lists_lock:
            try:
                data = {'detected': self.detected_printers, 'added_by_ip': self.printers_add_by_ip}
                json_config = json.dumps(data, indent = 4, separators = (',', ': '))
                with open(tmp_path, 'w') as f:
                    f.write(json_config)
                os.replace(tmp_path, file_path)
            except Exception as e:
                self.logger.error("Error writing data to %s: %s" % (file_path, str(e)))

    def load_printers_list(self):
        """Load both printer lists from STORAGE_FILE_PATH. A missing file is ignored; other errors are logged."""
        with self.lists_lock:
            file_path = self.STORAGE_FILE_PATH
            try:
                with open(file_path, 'r') as f:
                    data = f.read()
                    self.logger.info("Stored printers list: %s" % data)
                    printers_list = json.loads(data)
                    self.detected_printers = printers_list.get('detected', [])
                    self.printers_add_by_ip = printers_list.get('added_by_ip', [])
            except FileNotFoundError:
                pass
            except Exception as e:
                self.logger.error("Error reading data from %s: %s" % (file_path, str(e)))

    def run_detection_thread(self):
        """Start detect() in a background thread.

        Returns False if a previous detection thread is still running, else True.
        Results become available through get_just_detected_printers when done.
        """
        if self.detection_thread and self.detection_thread.is_alive():
            return False
        self.detection_thread = threading.Thread(target=self.detect)
        # start(), not run(): run() would execute detect() synchronously in the caller.
        self.detection_thread.start()
        return True

    def get_just_detected_printers(self, clear=True):
        """Return the printers found by the last detect() run, optionally clearing the stored list."""
        with self.lists_lock:
            printers_list = self.just_detected_printers
            if clear:
                self.just_detected_printers = []
        return printers_list

    def edit_printer(self, edit_dict):
        """Apply ``edit_dict`` values (for keys in ALL_KEYS) to the first matching remembered printer.

        Empty-string or None values delete the key. Returns True if a printer was
        edited and saved, False if none matched.
        """
        with self.lists_lock:
            for existing_printer in self.printers_add_by_ip + self.detected_printers:
                if self.is_same_printer(existing_printer, edit_dict):
                    for key in self.ALL_KEYS:
                        if key in edit_dict:
                            if edit_dict[key] == '' or edit_dict[key] is None:
                                if key in existing_printer:
                                    del existing_printer[key]
                            else:
                                existing_printer[key] = edit_dict[key]
                    self.save_printers_list()
                    return True
            return False


if __name__ == '__main__':
    # Manual run: enable debug logging, scan the LAN once and dump what was found.
    logging.basicConfig(level=logging.DEBUG)
    found_printers = NetworkDetector().detect()
    separator = "=" * 80
    print(separator)
    print("Network printers: ")
    pprint.pprint(found_printers)
