# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.

# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import logging
import threading
import time
import socket
import struct
import json
import uuid
import base64
import pathlib
import io
import os
import base_sender
import log
import typing
from formlabs_preform_uploader import PreFormUploader
from enum import Enum


class StatusType(Enum):
    """Kind of status payload a FormLabs printer model reports.

    The sender picks one of these from the printer's MachineTypeId when it
    connects; the value decides which material / volume / temperature fields
    are read from each status reply.
    """

    UNDEFINED = 0         # model not recognised: no material data is parsed
    SINGLE_CARTRIDGE = 1  # tank + one cartridge material and volume
    DUAL_CARTRIDGE = 2    # tank + front and back cartridge materials and volumes
    FUSE = 3              # cylinder/printer material and bed temperature, no volumes


class Sender(base_sender.BaseSender):
    """Sender for FormLabs printers driven through the printer's TCP protocol.

    Every request is a JSON document prefixed with its 4-byte little-endian
    length and terminated by ``TCP_PACKAGE_END``. A daemon thread
    (``_status_loop``) polls the printer status every ``LOOP_TIME`` seconds and
    also performs ``.formjob``/``.gcode``/``.g`` uploads requested through
    ``_start_uploading``, so status polling and job uploads share one thread
    and one socket. Other file types are handed to the PreForm application via
    ``PreFormUploader``.
    """

    LOOP_TIME = 4  # seconds between status requests
    UPLOAD_REPLY_LENGTH = 159  # recv() size used when reading per-layer upload replies
    RAW_DATA_CHUNK_SIZE = 15640
    MAX_ERROR_BEFORE_ASSUMING_DISCONNECT = 6
    READ_UPLOAD_RESPONSE_RETRIES = 5
    TCP_TIMEOUT = 4  # seconds; used as socket timeout and as SO_LINGER timeout
    TCP_PACKAGE_END = b'\x00' * 8  # terminator appended to / expected after every package

    # Substrings of MachineTypeId selecting how status replies are parsed.
    _SINGLE_CARTRIDGE_TYPES = ("FORM-2-", "FORM-3-", "DGJR-1-", "FRMB-3-")
    _DUAL_CARTRIDGE_TYPES = ("FRML-3-", "DGSR-1-", "FRBL-3-")
    _FUSE_TYPES = ("PILK-1-", "FS30-1-")
    # Job files uploaded directly over TCP; anything else goes through PreForm.
    _DIRECT_UPLOAD_EXTENSIONS = (".formjob", ".gcode", ".g")
    # Replies that acknowledge the first (job start) package of an upload.
    _JOB_START_METHODS = ("PROTOCOL_METHOD_START_FORM2_JOB", "PROTOCOL_METHOD_START_JOB")

    def __init__(self, parent: typing.Any, usb_info: dict, profile: typing.Optional[dict] = None):
        """Create the sender and, unless usb_info["RUNONCE"] is set, connect to the printer.

        Args:
            parent: owner object, passed through to BaseSender.
            usb_info: printer description; "IP" is required, "SNR" and "RUNONCE" are optional.
            profile: printer profile; "thickness" and "material_code" are used for PreForm uploads.

        Raises:
            RuntimeError: if usb_info has no IP, or the printer does not answer
                the information request (see _connect).
        """
        if profile is None:
            profile = {}  # fresh dict per instance instead of a shared mutable default
        super().__init__(parent, usb_info, profile)
        self.stop_flag = False
        self.profile = profile
        self.serial_number = usb_info.get("SNR", "")
        self.printer_ip = usb_info.get("IP", "")
        self.run_once = usb_info.get('RUNONCE')
        self.printer_port = 35  # default FormLabs port
        if not self.printer_ip:
            raise RuntimeError("Can't connect by IP to FormLabs Printer with " + str(self.usb_info))
        self.last_upload_result = None  # outcome of the latest TCP upload (True = success)
        self.is_uploading = False
        self.operational_flag = False
        self.printing_flag = False
        self.paused_flag = False
        self.primed = False
        self.percent = 0
        self.total_lines = 0
        self.lines_sent = 0
        self.temps = [0, 0, 0]
        self.target_temps = [0, 0, 0]
        self.tcp_client = None
        self.file_path_to_upload = ""  # set by _start_uploading, consumed by _status_loop
        self.status_type = StatusType.UNDEFINED
        self.cartridge_error = ""
        self.tank_material_code = ""
        self.tank_error = None
        self.machine_type_id = ""  # FormLabs printer type (FORM-2-0, FORM-3-0, etc...)
        self.id_bin = ""  # base64 of the raw GET_INFORMATION reply
        self.status_bin = ""  # base64 of the raw status reply last reported via get_ext
        self.last_status_bin = ""  # base64 of the most recent raw status reply
        self.status_thread = None
        self.job_upload_uuid = ""
        self.materials_list = []
        self.materials_volume = []
        if not self.run_once:
            self._connect()

    def _start_status_thread(self) -> None:
        """Start the daemon thread running _status_loop (no-op if alive or in RUNONCE mode)."""
        if self.status_thread and self.status_thread.is_alive():  # TODO remove this?
            self.logger.warning("FormLabs status loop already running. Skipping...")
            return
        if self.run_once:
            return
        # Called for its side effect of creating the PreFormUploader singleton
        # early (presumably) - TODO confirm whether this is still required.
        PreFormUploader.getInstance()
        self.status_thread = threading.Thread(target=self._status_loop, daemon=True)
        self.status_thread.start()

    def _send_once(self, data: bytes) -> typing.Optional[int]:
        """Perform one socket.send(); return the byte count, or None on socket error."""
        if not self.tcp_client:
            return None
        try:
            return self.tcp_client.send(data)
        except OSError as e:
            self.logger.error(f"Socket error({e}) on sending to printer {self.printer_ip}")
            time.sleep(1)
            return None

    def _write_tcp_package(self, data: typing.AnyStr) -> int:
        """Send data over the TCP socket.

        A zero-byte send gets one retry after a 1 s pause; a socket error ends
        the write. Returns the number of bytes actually sent, so a result lower
        than len(data) means the write failed (socket error, no socket, or
        stop requested).
        """
        size = len(data)
        total_sent = 0
        if not self.tcp_client:
            return total_sent
        while total_sent < size and not self.stop_flag:
            sent = self._send_once(data[total_sent:])
            if sent == 0:
                # a second chance (errors are caught here too, unlike before)
                time.sleep(1)
                sent = self._send_once(data[total_sent:])
            if not sent:
                break
            total_sent += sent
        return total_sent

    def _send_recv_tcp(self, data: typing.AnyStr, read_size: int = 1024) -> bytes:
        """Send a full package and read the reply; return b"" if the send was incomplete."""
        if self._write_tcp_package(data) == len(data):
            return self._read_tcp_package(read_size)
        return b""

    def _read_tcp_package(self, chunk_size: int = 1024) -> bytes:
        """Read from the socket until TCP_PACKAGE_END appears in the received data.

        Returns the accumulated bytes, or b"" when there is no socket or nothing
        at all was received. Reading stops as soon as recv() yields no data
        (timeout, socket error or the peer closing the connection) so a dropped
        connection can no longer make this loop spin forever; whatever partial
        data was already received is returned in that case.
        """
        response = b""
        while not self.stop_flag:
            if not self.tcp_client:
                self.logger.error(f"FormLabs: Error reading from {self.printer_ip} - no tcp_client")
                return b""
            chunk = b""
            try:
                chunk = self.tcp_client.recv(chunk_size)
            except OSError as e:
                if isinstance(e, socket.timeout):
                    self.logger.info(f"FormLabs: Socket timed out for {self.printer_ip}.")
                else:
                    self.logger.warning(
                        f"FormLabs: Unexpected error in socket read from {self.printer_ip}. Details: {e}")
                self.register_error(2014, "Error reading data from printer. Description: " + str(e), is_blocking=True)
                self.operational_flag = False
            if not chunk:
                self.logger.warning("Formlabs: Unexpected empty socket chunk received " + self.printer_ip)
                if not response:
                    self.logger.error(f"FormLabs: Error in TCP package read from {self.printer_ip}. Empty response")
                    return b""
                return response
            response += chunk
            if self.TCP_PACKAGE_END in response:
                break
        return response

    def _find_json(self, raw_data: bytes) -> dict:
        """Parse the JSON object spanning the first '{' to the last '}' of raw_data.

        Returns {} (and logs) when raw_data is empty, has no braces, is not
        valid JSON, or does not decode to a dict.
        """
        if not raw_data:
            return {}
        try:
            data = raw_data.decode(errors="ignore")
            result = json.loads(data[data.index("{"):data.rindex("}") + 1])
        except (ValueError, RecursionError) as e:  # JSONDecodeError and str.index errors are ValueErrors
            self.logger.error(f'Exception in find_json:{e}\n{raw_data}')
            return {}
        if not isinstance(result, dict):
            self.logger.error(f'Exception in find_json:Found json in not dict\n{raw_data}')
            return {}
        return result

    def _find_valid_json(self, data: bytes) -> dict:
        """Return the reply dict only if its "Success" is truthy; otherwise {}.

        A reply carrying an "Error" is additionally registered as info error 2001.
        """
        result = self._find_json(data)
        if result:
            if result.get("Success"):
                return result
            if result.get("Error"):
                self.register_error(2001, f"FormLabs: received response with error from {self.printer_ip}. Details: "
                                          f'{str(result.get("Error"))}. ReplyToMethod: {str(result.get("ReplyToMethod"))}',
                                    is_info=True)
        return {}

    def _get_uuid_bytes(self) -> bytes:
        """Return a fresh random UUID formatted as b"{xxxxxxxx-...}".

        Built from the string form because the raw uuid.bytes do not format
        correctly with b"%s" % uuid.bytes.
        """
        return ("{" + str(uuid.uuid4()) + "}").encode('utf-8')

    def _create_request_json(self, method_name: str, parameters: typing.Optional[dict] = None) -> typing.Optional[dict]:
        """Build a protocol request dict with a random braced Id; None if method_name is empty."""
        if not method_name:
            return None
        request_json = {"Id": "{" + str(uuid.uuid4()) + "}", "Method": method_name}
        if parameters:
            request_json["Parameters"] = parameters
        request_json["Version"] = 1
        return request_json

    def _request_from_printer(self, request_json: dict, attr_name_to_save: str = "") -> dict:
        """Send request_json and return the successful reply dict (or {}).

        If attr_name_to_save is given and the reply is valid, the raw reply is
        stored base64-encoded in that attribute (e.g. id_bin, last_status_bin).
        """
        payload = (json.dumps(request_json, indent=4) + '\n').encode('utf-8')
        # Length prefix counts encoded bytes (identical to char count for ASCII JSON).
        request_data = len(payload).to_bytes(4, byteorder='little') + payload + self.TCP_PACKAGE_END
        response_data = self._send_recv_tcp(request_data)
        response_json = self._find_valid_json(response_data)
        if attr_name_to_save:
            if response_json:
                setattr(self, attr_name_to_save, base64.b64encode(response_data).decode('utf-8'))
            else:
                self.logger.error(
                    f"Error in validating printer info in FormLabs TCP data package from {self.printer_ip}. request_data: {request_data} response_data: {response_data}")
        return response_json

    def _get_printer_info_tcp(self) -> dict:
        """Request PROTOCOL_METHOD_GET_INFORMATION; the raw reply is kept in id_bin."""
        return self._request_from_printer(self._create_request_json("PROTOCOL_METHOD_GET_INFORMATION"), 'id_bin')

    def _get_printer_status_tcp(self) -> dict:
        """Request PROTOCOL_METHOD_GET_STATUS; the raw reply is kept in last_status_bin."""
        return self._request_from_printer(self._create_request_json("PROTOCOL_METHOD_GET_STATUS"), 'last_status_bin')

    def send_abort_upload_tcp(self, job_uuid: str) -> dict:
        """Ask the printer to abort the job with the given Guid; {} if job_uuid is empty."""
        if not job_uuid:
            return {}
        param = {"Guid": job_uuid}
        return self._request_from_printer(self._create_request_json("PROTOCOL_METHOD_ABORT_JOB", param))

    def _get_printer_calibration_tcp(self, material_code: str, layerThickness_mm: int) -> dict:
        """Request calibration data for a material/layer thickness; {} if either is empty."""
        if not material_code or not layerThickness_mm:
            return {}
        param = {
            "layerThickness_mm": layerThickness_mm,
            "material_code": material_code
        }
        return self._request_from_printer(self._create_request_json("PROTOCOL_METHOD_GET_CALIBRATION", param))

    def get_current_line_number(self) -> int:
        # NOTE(review): current_line_number is not assigned in this class;
        # presumably provided by BaseSender - TODO confirm.
        return self.current_line_number

    def _close_tcp_client(self) -> None:
        """Close the socket (if any) and mark the printer as not operational."""
        if self.tcp_client:
            self.logger.info("FormLabs: Closing TCP socket connection to %s" % self.printer_ip)
            try:
                self.tcp_client.close()
            except OSError as e:
                self.logger.warning(f'FormLabs: Error closing TCP socket to {self.printer_ip}. Description: {e}')
            self.operational_flag = False
            self.tcp_client = None

    def _init_tcp_client(self) -> None:
        """(Re)open the TCP connection; on failure the socket is left as None."""
        self._close_tcp_client()
        try:
            self.tcp_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.tcp_client.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, self.TCP_TIMEOUT))
            self.tcp_client.settimeout(self.TCP_TIMEOUT)
            self.tcp_client.connect((self.printer_ip, self.printer_port))
        except OSError as e:
            self.logger.warning(f'Error connecting to printer: {str(e)}')
            self._close_tcp_client()

    @staticmethod
    def _material_code(params: dict, key: str) -> typing.Optional[str]:
        """Return params[key], mapping the printer's "NONE" placeholder to ""."""
        code = params.get(key)
        return "" if code == "NONE" else code

    @staticmethod
    def _remaining_volume_ml(params: dict, prefix: str) -> int:
        """Return max(0, <prefix>OriginalVolume_mL - <prefix>EstimatedVolumeDispensed_mL)."""
        original = int(params.get(prefix + "OriginalVolume_mL") or 0)
        dispensed = int(params.get(prefix + "EstimatedVolumeDispensed_mL") or 0)
        return max(0, original - dispensed)

    def _request_and_parse_printer_status(self, skip_sending: bool = False) -> str:
        """Fetch (or, with skip_sending, only read) a status reply and update sender state.

        Returns an error description, or "" when the status was parsed.
        """
        if skip_sending:
            status = self._find_valid_json(self._read_tcp_package())
        else:
            status = self._get_printer_status_tcp()
        if not status or not isinstance(status, dict):
            return f"Printer status from {self.printer_ip} returns non dict response"
        params = status.get("Parameters")
        if not params or not isinstance(params, dict):
            return f"Printer status from {self.printer_ip} got no field Parameters"
        tmp_material_list = []
        if self.status_type == StatusType.SINGLE_CARTRIDGE:
            tmp_material_list = [self._material_code(params, "tankMaterialCode"),
                                 self._material_code(params, "cartridgeMaterialCode")]
            self.materials_volume = [0, self._remaining_volume_ml(params, "cartridge")]
        elif self.status_type == StatusType.DUAL_CARTRIDGE:
            tmp_material_list = [self._material_code(params, "tankMaterialCode"),
                                 self._material_code(params, "frontCartridgeMaterialCode"),
                                 self._material_code(params, "backCartridgeMaterialCode")]
            self.materials_volume = [0,
                                     self._remaining_volume_ml(params, "frontCartridge"),
                                     self._remaining_volume_ml(params, "backCartridge")]
        elif self.status_type == StatusType.FUSE:
            # material volumes not supported by fuse status data
            self.temps = [int(params.get("bedTemperature_C") or 0), 0, 0]
            tmp_material_list = [self._material_code(params, "cylinderMaterialCode"),
                                 self._material_code(params, "printerMaterial")]
        update_bin_data = tmp_material_list != self.materials_list
        if update_bin_data:
            self.materials_list = tmp_material_list
        # Publish the raw status only on first reply, printing state change or material change.
        if not self.status_bin \
                or self.printing_flag != params.get("isPrinting") \
                or update_bin_data:
            self.status_bin = self.last_status_bin
        self.primed = params.get("isPrimed", False)
        self.cartridge_error = params.get("cartridgeError", "CARTRIDGE_SUCCESS")
        self.tank_error = params.get("tankError", "TANK_SUCCESS")
        if self.cartridge_error != "CARTRIDGE_SUCCESS" or self.tank_error != "TANK_SUCCESS":
            self.operational_flag = False
        estimated_total = params.get("estimatedTotalPrintTime_ms")
        estimated_remaining = params.get("estimatedPrintTimeRemaining_ms")
        if estimated_total:
            self.estimated_time = estimated_total // 1000
        self.printing_flag = params.get("isPrinting", False)
        if estimated_total and estimated_remaining is not None:
            percent = 100 - int(estimated_remaining / estimated_total * 100)
            self.percent = min(100, max(0, percent))
        else:
            # remaining time missing: previously raised TypeError and killed the status thread
            self.percent = 0
        return ""

    @log.log_exception
    def _status_loop(self) -> None:
        """Poll the printer status every LOOP_TIME seconds and run pending TCP uploads.

        After MAX_ERROR_BEFORE_ASSUMING_DISCONNECT consecutive failed status
        requests the printer is marked not operational. The socket is closed
        when stop_flag ends the loop.
        """
        retries_left = self.MAX_ERROR_BEFORE_ASSUMING_DISCONNECT
        while not self.stop_flag:
            loop_begin_time = time.monotonic()
            try:
                error = self._request_and_parse_printer_status()
            except RuntimeError as e:  # FIXME replace with `except Exception` in production!
                error = str(e)
                self.logger.exception(f'Exception in FormLabs sender status loop: {e}')
            if error:
                retries_left -= 1
            else:
                retries_left = self.MAX_ERROR_BEFORE_ASSUMING_DISCONNECT
            if not retries_left:
                self.operational_flag = False
                self.logger.info(
                    f'Printer disconnected after {self.MAX_ERROR_BEFORE_ASSUMING_DISCONNECT} failed status requests')
            # Sleep only for what is left of the period (previously elapsed time was added).
            remaining = self.LOOP_TIME - (time.monotonic() - loop_begin_time)
            if remaining > 0:
                time.sleep(remaining)
            if self.file_path_to_upload:
                try:
                    time.sleep(1)  # give a printer some time to process requests
                    self.last_upload_result = self._upload_job()
                finally:
                    # Always release the thread waiting in _start_uploading.
                    self.file_path_to_upload = ""
        self._close_tcp_client()
        self.logger.info("FormLabs %s exit status loop thread" % self.printer_ip)

    def _connect(self) -> None:
        """Open the socket, read printer info, detect the status layout and start polling.

        Raises:
            RuntimeError: if the printer does not return information with "Parameters".
        """
        self.logger.info("Connecting to FormLabs printer via TCP socket")
        if not self.tcp_client:
            self._init_tcp_client()
        printer_info = self._get_printer_info_tcp()
        if not printer_info or not printer_info.get("Parameters"):
            self.logger.warning("Printer info from %s returns None" % self.printer_ip)
            self._close_tcp_client()
            raise RuntimeError("Unable to connect to FormLabs printer " + str(self.printer_ip))
        printer = printer_info["Parameters"].get("printer") or {}
        self.serial_number = printer.get("Serial", "")
        self.machine_type_id = printer.get("MachineTypeId", "")
        if any(t in self.machine_type_id for t in self._SINGLE_CARTRIDGE_TYPES):
            self.status_type = StatusType.SINGLE_CARTRIDGE
        elif any(t in self.machine_type_id for t in self._DUAL_CARTRIDGE_TYPES):
            self.status_type = StatusType.DUAL_CARTRIDGE
        elif any(t in self.machine_type_id for t in self._FUSE_TYPES):
            self.status_type = StatusType.FUSE
        self.operational_flag = True
        self._start_status_thread()

    def set_filename(self, filename: str) -> None:
        self.filename = filename

    @staticmethod
    def _parse_package_lengths(line: str) -> typing.List[typing.Tuple[int, int]]:
        """Parse "h1,r1,h2,r2,..." into [(h1, r1), (h2, r2), ...].

        Raises:
            ValueError: on a non-integer field or an odd number of fields.
        """
        values = [int(value) for value in line.rstrip('\n').split(",")]
        if len(values) % 2:
            raise ValueError("odd number of package length values")
        return list(zip(values[0::2], values[1::2]))

    def _parse_and_upload_job_file(self, file_path: str) -> str:
        """Stream a prepared job file to the printer over TCP.

        Expected file layout (as parsed below): a "serial:guid" line, a line of
        comma-separated (head length, raw data length) pairs, then the package
        bytes back to back. Returns "" on success or an error description.
        """
        with open(file_path, mode='rb') as file:
            try:
                serial, guid = file.readline().decode(errors="ignore").rstrip('\n').split(':')
            except OSError:
                return f"Aborting FormLabs job upload for {self.printer_ip}. Description: unreadable job file"
            except ValueError:
                return f"Aborting FormLabs job upload for {self.printer_ip}. Description: unreadable job file (Value Error)"
            if serial != self.serial_number:
                return f"Aborting FormLabs job upload for {self.printer_ip}. Description: a job was sliced for a different printer and is incompatible"
            try:
                line = file.readline().decode(errors="ignore")
                if not line:
                    raise OSError("Unexpected end of file")
            except OSError as e:
                return f"Aborting FormLabs job upload for {self.printer_ip}. Description: Unable to read package lengths. Error:{e}"
            try:
                packages_lengths_pairs = self._parse_package_lengths(line)
            except (TypeError, ValueError, IndexError) as e:
                return f"Aborting FormLabs job upload for {self.printer_ip}. Description:{e} on parsing packages lengths. Invalid line:{line}"
            job_start_reply = None
            expected_replies = len(packages_lengths_pairs)
            for head_data_len, raw_data_len in packages_lengths_pairs:
                package_len = head_data_len + raw_data_len
                try:
                    job_package = file.read(head_data_len) + file.read(raw_data_len)
                    if not job_package or len(job_package) != package_len:
                        raise OSError(f"Unexpected end of file: expected {package_len} bytes, got {len(job_package)}")
                except OSError as e:
                    return f"Aborting FormLabs job upload for {self.printer_ip}. Description:{e} on reading packages"
                bytes_sent = self._write_tcp_package(job_package)
                if bytes_sent != package_len:
                    # giving a second chance
                    time.sleep(1)
                    bytes_sent += self._write_tcp_package(job_package[bytes_sent:])
                    if bytes_sent != package_len:
                        self.send_abort_upload_tcp(guid)
                        return f"Abort printer upload for {self.printer_ip}. Socket write error"
                if not job_start_reply:
                    # first acknowledged package must be the job start
                    job_start_reply = self._find_json(self._read_tcp_package())
                    if job_start_reply:
                        if job_start_reply.get("ReplyToMethod") not in self._JOB_START_METHODS:
                            self.send_abort_upload_tcp(guid)
                            return f"Aborting FormLabs job upload for {self.printer_ip}. Description: wrong job start reply)"
                        if not job_start_reply.get("Success"):
                            return f"Error uploading job for {self.printer_ip}. Details: {job_start_reply.get('Error', '')}"
                        expected_replies -= 1
                else:
                    # wait reply for previous layer upload
                    layer_upload_reply = self._find_json(self._read_tcp_package(self.UPLOAD_REPLY_LENGTH))
                    if layer_upload_reply.get("ReplyToMethod") == "PROTOCOL_METHOD_UPLOAD_LAYER":
                        expected_replies -= 1
                        # check the layer reply itself (was checking job_start_reply)
                        if not layer_upload_reply.get("Success"):
                            self.send_abort_upload_tcp(guid)
                            return f"Error uploading job for {self.printer_ip}. Description: layer upload failed. Details: {layer_upload_reply.get('Error', '')}"
            # Collect layer acknowledgements that were not read during the upload.
            retries_left = self.READ_UPLOAD_RESPONSE_RETRIES
            while not self.stop_flag and expected_replies > 0:
                reply = self._get_printer_status_tcp()
                if reply.get("ReplyToMethod") == "PROTOCOL_METHOD_UPLOAD_LAYER":
                    expected_replies -= 1
                    continue
                # any other (or empty) reply consumes a retry, so this loop always terminates
                retries_left -= 1
                if not retries_left:
                    self.send_abort_upload_tcp(guid)
                    return f"Error uploading job for {self.printer_ip}. Details: {expected_replies} upload replies were lost"
                time.sleep(self.LOOP_TIME)
        return ""

    def _upload_job(self) -> bool:  # TODO Finish refactoring this. Remove this method entirely
        """Upload file_path_to_upload over TCP, registering error 6100 on failure.

        Returns True on success, False on failure.
        """
        self.logger.info(f"Start job upload for {self.printer_ip}")
        try:
            error = self._parse_and_upload_job_file(self.file_path_to_upload)
        except (RuntimeError, OSError) as e:  # FIXME replace with `except Exception` in production!
            # OSError covers e.g. a missing job file, which previously killed the status thread
            self.logger.exception(f'Exception in FormLabs sender upload_job: {e}')
            error = str(e)
        if error:
            self.register_error(6100, error, is_blocking=False)
        return not error

    def _upload_preform(self, file_path: str) -> bool:
        """Hand the file to the PreForm uploader and wait (1 s polling) for its result.

        Returns True only when PreForm reports success; False otherwise or on stop.
        """
        preform_uploader = PreFormUploader.getInstance()
        if not self.operational_flag:
            self.logger.warning("Cancel job upload. No connection to FormLabs printer on %s" % self.printer_ip)
            return False
        if not preform_uploader.is_supported():
            self.register_error(6000, "Upload job via PreForm is not supported (Wrong OS or PreForm app is not found)",
                                is_blocking=False, is_info=True)
            self.logger.warning("Cancel job upload. PreForm app not found.")
            return False
        if not file_path:
            # previously job_id was left unbound here and raised NameError below
            self.logger.warning("Cancel job upload. No file to upload for FormLabs printer on %s" % self.printer_ip)
            return False
        profile_thickness = self.profile.get("thickness", "")
        profile_material = self.profile.get("material_code", "")
        job_id = preform_uploader.add_job(self.printer_ip, file_path, profile_thickness, profile_material)
        while not self.stop_flag:
            time.sleep(1)
            with preform_uploader.call_lock:
                result = preform_uploader.get_job_result(job_id)
            if result:
                if result.get("success"):
                    return True
                self.register_error(6001, f"PreForm job upload failed. Details: {result.get('message', '')}",
                                    is_blocking=False, is_info=True)
                break
        return False

    def gcodes(self, gcodes_or_file: typing.Union[typing.AnyStr, io.BytesIO], keep_file: bool = False) -> bool:
        """Upload and start the job (keep_file is accepted for interface compatibility only)."""
        return self.upload_to_printer(gcodes_or_file, autostart=True)

    def _start_uploading(self, file_path: str) -> bool:  # TODO remove this after remaking upload_to_printer
        """Queue file_path for the status thread and block until it is processed.

        Returns True if the TCP upload succeeded, False on failure or stop.
        """
        self.logger.info('Formlabs: Setting upload file path to: ' + str(file_path))
        self.last_upload_result = None  # reset before queueing so a stale result is never reported
        self.file_path_to_upload = file_path
        while not self.stop_flag and self.file_path_to_upload:
            time.sleep(1)
        return bool(self.last_upload_result)

    @log.log_exception
    def upload_to_printer(self, gcodes_or_file: typing.Union[str, io.BytesIO], autostart: bool = False):
        """Copy the job to a temporary file, upload it, then delete the copy.

        .formjob/.gcode/.g files are sent over TCP by the status thread; other
        files go through PreForm. Returns True on success, False otherwise.
        """
        self.percent = 0
        self.estimated_time = 0
        self.is_uploading = True
        try:
            preform_uploader = PreFormUploader.getInstance()
            file_path = preform_uploader.create_file_copy(gcodes_or_file, self.filename)
            if not file_path:
                self.logger.warning(f'Cancel job upload. Unable to create job file copy for {self.printer_ip}')
                return False
            if self.primed:
                self.register_error(2000, "Prime detected. Printer will start the print after upload", is_blocking=False,
                                    is_info=True)
            if pathlib.Path(file_path).suffix in self._DIRECT_UPLOAD_EXTENSIONS:
                result = self._start_uploading(file_path)
            else:
                result = self._upload_preform(file_path)
            try:
                os.remove(file_path)
            except OSError:
                self.logger.warning(f'Unable to remove file: {file_path}')
            return result
        finally:
            # reset even if an exception escapes, so the sender never looks stuck uploading
            self.is_uploading = False

    def is_printing(self) -> bool:
        return self.printing_flag

    def is_paused(self) -> bool:
        return self.paused_flag

    def is_operational(self) -> bool:
        return self.operational_flag

    def get_percent(self) -> float:
        """Print progress in percent, or 0 when not printing."""
        if self.is_printing():
            return self.percent
        return 0

    def pause(self) -> bool:
        self.parent.register_error(505, "Pause is not supported", is_blocking=False)
        return False

    def unpause(self) -> bool:
        self.parent.register_error(505, "Resume is not supported", is_blocking=False)
        return False

    def cancel(self) -> bool:
        self.parent.register_error(505, "Cancel is not supported", is_blocking=False)
        return False

    def get_ext(self) -> dict:
        """Extra status fields: raw status/info blobs, prime state and tank/cartridge errors."""
        ext_dict = {'status_bin': self.status_bin, 'id_bin': self.id_bin, "primed": self.primed}
        if self.tank_error:
            ext_dict['tank_error'] = self.tank_error
        if self.cartridge_error:
            ext_dict['cartridge_error'] = self.cartridge_error
        return ext_dict

    def get_material_names(self) -> typing.List[str]:
        return self.materials_list

    def get_material_volumes(self) -> typing.List[float]:
        return self.materials_volume

    def close(self) -> None:
        """Signal the status thread to stop and wait up to 2 * LOOP_TIME for it."""
        self.logger.info(f'Closing Formlabs sender {self.usb_info}')
        self.stop_flag = True
        if self.status_thread:
            try:
                self.status_thread.join(self.LOOP_TIME * 2)
            except RuntimeError:
                # join() raises RuntimeError e.g. when close() runs on the status thread itself
                pass

