 # Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.

# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import time
import typing
import base_sender
import aiohttp
import aiohttp.client
import paramiko
import asyncio
import hashlib
from urllib.parse import urlparse

import http_connection


class Raise3DConnection(http_connection.HTTPConnection):
    """Asyncio client for the Raise3D printer HTTP API (``/v1/...`` endpoints).

    Authenticates with a signed ``/v1/login`` request and caches the token.
    :meth:`status_requesting` polls temperatures, running status and job
    progress. The synchronous ``pause``/``resume``/``cancel``/``start_job``
    wrappers hand coroutines to ``run_async_method``. That method is not
    defined here; it is presumably inherited from ``HTTPConnection``.
    """

    DEFAULT_PORT = 10800
    TOKEN_ERROR = 401  # HTTP status that triggers a re-login and one retry
    HOST_MASK = "http://%s:%d"
    TIMEOUT = 2
    JOB_START_WAIT_TIMEOUT = 60  # seconds to wait for a newly created job to be reported as running

    def __init__(self, host=None, port=None, timeout=TIMEOUT, parent=None, logger=None):
        super().__init__(host, port, timeout, parent, logger)
        # Index 0 is the heatbed, 1 the left nozzle, 2 the right nozzle.
        self.temps = [0.0, 0.0, 0.0]
        self.ttemps = [0.0, 0.0, 0.0]
        self.api_token = ""
        self.cam_url = ""
        self.current_job_id = None
        # time.monotonic() of the last accepted "new job" request, or None.
        self.last_start_print_time = None
        if parent:
            self.api_passwd = parent.api_password
        else:
            self.api_passwd = None

    async def connect(self):
        """Open an HTTP session, check that the printer answers, then run the status loop.

        Any exception is logged as a warning. On exit the session is always
        closed and ``operational_flag`` is cleared.
        """
        self.logger.info(f"Connecting to Raise3D printer at {self.host}")
        try:
            self.session = aiohttp.ClientSession(loop=self.loop, timeout=aiohttp.ClientTimeout(total=self.timeout))
            info = await self.get_printer_system_info()
            if info.get("status", 0) == 1:
                await self.status_requesting()
            else:
                self.logger.error(info)
        except Exception as e:
            self.logger.warning(f"Unable to connect to {self.host}. Desc: {e}")
        finally:
            self.logger.info(f"Closing connection {self.host}")
            # getattr: the ClientSession constructor itself may have raised.
            session = getattr(self, "session", None)
            if session:
                await session.close()
            self.session = None
            self.operational_flag = False

    def get_ip_or_hostname(self, url):
        """Return the host part of ``url``, without port or credentials."""
        parsed_url = urlparse(url)
        return parsed_url.hostname or parsed_url.netloc.split(':')[0]

    def generate_sign(self, password, timestamp):
        """Return the login signature ``md5(sha1_hex("password=<p>&timestamp=<t>"))`` as hex.

        SHA-1/MD5 here are presumably dictated by the printer's login API, not a
        security choice made by this client.
        """
        sha1_str = f"password={password}&timestamp={timestamp}"
        sha1_result = hashlib.sha1(sha1_str.encode()).hexdigest()
        return hashlib.md5(sha1_result.encode()).hexdigest()

    @staticmethod
    def _split_error(error, default_code=None):
        """Return ``(code, msg)`` from an API ``error`` value that may be a dict or a plain string."""
        if isinstance(error, dict):
            return error.get("code", default_code), error.get("msg")
        return default_code, error

    async def _read_error(self, response):
        """Return the ``error`` field of a failed response, tolerating empty or non-JSON bodies."""
        try:
            body = await response.json(content_type=None)
        except (aiohttp.ContentTypeError, ValueError):
            body = None
        if isinstance(body, dict):
            return body.get('error', "Unknown error")
        return "Unknown error"

    async def login(self):
        """Log in via ``/v1/login`` and cache the returned token in ``api_token``.

        Returns ``{"token": <token or None>}`` on HTTP 200. An ``error`` in a 200
        body is also registered. On any other status it returns ``{"error": ...}``.
        """
        timestamp = str(int(time.time() * 1000))  # milliseconds since the epoch
        sign = self.generate_sign(self.api_passwd, timestamp)
        url = f"{self.host}/v1/login?sign={sign}&timestamp={timestamp}"
        async with self.session.get(url) as response:
            if response.status != 200:
                return {"error": await self._read_error(response)}
            data = await response.json()
            if not isinstance(data, dict):
                data = {}
            self.api_token = (data.get('data') or {}).get('token')
            if "error" in data:
                code, msg = self._split_error(data['error'], default_code=9999)
                # register_error handles a missing parent by logging instead.
                self.register_error(code, msg)
            return {"token": self.api_token}

    async def _make_request(self, method, endpoint, retry=False, **kwargs):
        """Send an authenticated request and return the decoded JSON body.

        It logs in first if no token is cached. On ``TOKEN_ERROR`` it logs in
        again and retries once; ``retry`` marks that second attempt. Failures are
        returned as ``{"error": ...}`` rather than raised.
        """
        if not self.api_token:
            await self.login()
        if not self.api_token:
            return {"error": "Unable to re-authenticate"}
        url = f"{self.host}{endpoint}?token={self.api_token}"
        async with self.session.request(method, url, **kwargs) as response:
            if response.status == 200:
                return await response.json()
            if response.status != self.TOKEN_ERROR or retry:
                error = await self._read_error(response)
                self.logger.info(error)
                return {"error": error}
        # Token rejected: release the response above, re-login and retry once.
        login_result = await self.login()
        if login_result.get("token"):
            return await self._make_request(method, endpoint, retry=True, **kwargs)
        return {"error": "Unable to re-authenticate"}

    async def get_printer_system_info(self):
        return await self._make_request("GET", "/v1/printer/system")

    async def get_running_status(self):
        return await self._make_request("GET", "/v1/printer/runningstatus")

    async def get_basic_info(self):
        # Used for the heatbed temperatures.
        return await self._make_request("GET", "/v1/printer/basic")

    async def get_camera_info(self):
        return await self._make_request("GET", "/v1/printer/camera")

    async def get_left_nozzle_info(self):
        return await self._make_request("GET", "/v1/printer/nozzle1")

    async def get_right_nozzle_info(self):
        return await self._make_request("GET", "/v1/printer/nozzle2")

    async def get_current_job_info(self):
        return await self._make_request("GET", "/v1/job/currentjob")

    async def create_new_job(self, file_path):
        """Ask the printer to print ``file_path``, a path on the printer's side.

        If the request is accepted, the request time is recorded. The status
        loop can then persist the new job id once the printer reports it.
        """
        payload = {"file_path": file_path}
        self.logger.info("Creating new job")
        resp = await self._make_request("POST", "/v1/job/newjob/set", json=payload)
        if resp.get("status", 0):
            self.last_start_print_time = time.monotonic()
            if self.parent:
                # NOTE(review): this saves the *previous* job id; the new id is not
                # known yet and is saved later by the status loop. Confirm intent.
                self.parent.save_current_printer_job_id(self.current_job_id)
        return resp

    async def set_job_status(self, operate):
        """Send a job control command; ``operate`` is "pause", "resume" or "stop"."""
        payload = {"operate": operate}
        return await self._make_request("POST", "/v1/job/currentjob/set", json=payload)

    def init_camera_url(self, camera_data):
        """Build the camera stream URL from a ``/v1/printer/camera`` response.

        The URL is stored in ``cam_url`` and registered with the parent.
        """
        data = camera_data.get("data", {})
        if data.get("is_camera_connected", False):
            username = data.get("user_name", "")
            pwd = data.get("password", "")
            # NOTE(review): credentials are embedded unescaped; a value containing
            # '@', ':' or '/' would yield a malformed URL.
            self.cam_url = f"http://{username}:{pwd}@{self.get_ip_or_hostname(self.host)}:30216/api/v1/camera/stream"
            if self.parent:
                self.parent.add_camera_url(self.cam_url)

    async def status_requesting(self):
        """Poll the printer until ``stop_flag`` is set or the session is gone.

        Polling is skipped while a command or print future is pending. After
        ``REQ_RETRY`` *consecutive* failed polls, a blocking "connection lost"
        error is registered once. ``stop_flag``, ``REQ_RETRY``, ``LOOP_TIME`` and
        ``SERVICE_NAME`` are not defined here; they are presumably inherited.
        """
        self.logger.info('Entering status loop...')
        retries_left = self.REQ_RETRY
        camera_resp = await self.get_camera_info()
        self.init_camera_url(camera_resp)
        while not self.stop_flag:
            loop_start_time = time.monotonic()
            if not self.session:
                break
            if not self.command_future and not self.print_future:
                try:
                    await self._poll_status()
                except Exception as e:
                    self.logger.warning(f"Status request failed: {e}")
                    retries_left -= 1
                    if not retries_left:
                        self.operational_flag = False
                        self.register_error(2000, self.SERVICE_NAME + ' connection lost', is_blocking=True)
                else:
                    # Only consecutive failures should count toward "connection lost".
                    retries_left = self.REQ_RETRY
            # Sleep only for whatever remains of LOOP_TIME after this iteration's work.
            delta = loop_start_time + self.LOOP_TIME - time.monotonic()
            if delta > 0:
                await asyncio.sleep(delta)
        self.logger.info('Status loop exit')

    async def _poll_status(self):
        """Run one polling cycle: temperatures, running state and current job."""
        info_resp = await self.get_basic_info()
        left_nozzle_resp = await self.get_left_nozzle_info()
        right_nozzle_resp = await self.get_right_nozzle_info()
        status_resp = await self.get_running_status()
        if self.last_start_print_time and time.monotonic() - self.last_start_print_time > self.JOB_START_WAIT_TIMEOUT:
            self.logger.warning("Wait time for job start elapsed. Clean flags ...")
            self.last_start_print_time = None
        if info_resp.get("status", 0) == 1:
            self.operational_flag = True
        self._update_temperatures(info_resp, left_nozzle_resp, right_nozzle_resp)
        if status_resp.get("status", 0) == 1:
            status = (status_resp.get("data") or {}).get("running_status")
            await self._update_job_state(status)

    def _update_temperatures(self, bed_resp, left_resp, right_resp):
        """Copy current/target temperatures from successful responses into ``temps``/``ttemps``."""
        if bed_resp.get("status", 0) == 1:
            bed = bed_resp.get("data") or {}
            self.temps[0] = bed.get("heatbed_cur_temp", self.temps[0])
            self.ttemps[0] = bed.get("heatbed_tar_temp", self.ttemps[0])
        for index, resp in ((1, left_resp), (2, right_resp)):
            if resp.get("status", 0) == 1:
                nozzle = resp.get("data") or {}
                self.temps[index] = nozzle.get("nozzle_cur_temp", self.temps[index])
                self.ttemps[index] = nozzle.get("nozzle_tar_temp", self.ttemps[index])

    async def _update_job_state(self, status):
        """Update printing/paused flags and, while a job is active, its id and progress."""
        self.printing_flag = status == "running"
        self.paused_flag = status == "paused"
        if not (self.printing_flag or self.paused_flag):
            # NOTE(review): the pre-connect job check below only runs while a job is
            # active, so an idle printer after reconnect is never checked. Confirm intent.
            return
        job_info_resp = await self.get_current_job_info()
        if job_info_resp.get("status", 0) == 1:
            job_data = job_info_resp.get("data") or {}
            self.current_job_id = job_data.get("job_id")
            if self.parent:
                self.parent.printers_job_id = self.current_job_id
            progress = job_data.get("print_progress", 0)
            self.percent = round(progress, 2) if progress is not None else 0
        if (self.last_start_print_time and self.current_job_id
                and time.monotonic() - self.last_start_print_time < self.JOB_START_WAIT_TIMEOUT):
            # The job started by create_new_job is now visible: persist its id once.
            self.last_start_print_time = None
            self.logger.info(f"Save current_job_id {self.current_job_id}")
            if self.parent:
                self.parent.printers_job_id = self.current_job_id
                self.parent.save_current_printer_job_id(self.current_job_id)
        if self.parent:
            self.parent.check_preconnect_printer_job()

    def parse_command_response(self, data):
        """Return True if ``data`` is a dict with ``status == 1``; otherwise log the error."""
        success = False
        if isinstance(data, dict):
            success = data.get('status', 0) == 1
            if not success:
                # ``error`` may be a dict ({"code", "msg"}) or a plain string from _make_request.
                code, msg = self._split_error(data.get('error', {}))
                self.logger.error(f"Command failed, desc: {code},  {msg}")
        return success

    def pause(self):
        return self.run_async_method(self.set_job_status, "pause")

    def resume(self):
        return self.run_async_method(self.set_job_status, "resume")

    def cancel(self):
        return self.run_async_method(self.set_job_status, "stop")

    def start_job(self, path):
        return self.run_async_method(self.create_new_job, path)

    def register_error(self, code, message, is_blocking=False, is_info=False):
        """Forward an error to the parent sender, or log it if there is no parent."""
        if self.parent:
            self.parent.register_error(code, message, is_blocking, is_info=is_info)
        else:
            self.logger.error(f"Register error: {code} {message} blocking={is_blocking} is_info={is_info}")

    def close(self):
        """Stop the status loop, cancel pending futures, join the status thread and close the loop."""
        self.logger.info("Closing " + __class__.__name__)
        self.stop_flag = True
        if self.print_future:
            self.print_future.cancel()
        if self.command_future:
            self.command_future.cancel()
        self.status_thread.join(self.TIMEOUT + 0.1)
        if self.loop:
            try:
                self.loop.close()
            except RuntimeError:
                # loop.close() raises RuntimeError while the loop is still running; ignored here.
                pass
        time.sleep(0.1)


class Sender(base_sender.BaseSender):
    """3DPrinterOS sender for Raise3D printers.

    G-code is uploaded over SFTP (paramiko). Jobs are started, paused, resumed
    and cancelled through the printer's HTTP API via :class:`Raise3DConnection`.
    """

    API_PORT = 10800
    SSH_PORT = 22
    SSH_USERNAME = "root"
    # SSH_USERNAME = "user"
    UPLOAD_PATH = "/home/root"
    # UPLOAD_PATH = "/tmp/"
    LOOP_SLEEP_SEC = 5
    CONNECT_TIMEOUT = 6  # seconds, used for the SSH/TCP connect in _upload
    CHUNK_SIZE = 1024 * 1024
    UPLOAD_FILENAME = "3dprinteros.gcode"
    DEFAULT_TEMPS_LIST = [0.0, 0.0, 0.0]
    SUPPORT_JOBS = True

    def __init__(self, parent: typing.Any, usb_info: dict, profile: typing.Optional[dict] = None):
        # Fresh dict per instance: a shared ``{}`` default would leak state between senders.
        if profile is None:
            profile = {}
        super().__init__(parent, usb_info, profile)
        self.serial_number = (usb_info.get("SNR") or "").upper()
        self.host = usb_info.get("IP", "")
        self.run_once = usb_info.get('RUNONCE')
        self.api_password = usb_info.get('PASS')
        self.ssh_password = usb_info.get('SSHP')
        if not self.run_once:
            # A missing/None profile timeout would otherwise disable the HTTP timeout.
            self.connection = Raise3DConnection(self.usb_info['IP'],
                                                self.API_PORT,
                                                self.profile.get('timeout') or Raise3DConnection.TIMEOUT,
                                                self, self.logger)

    def _upload(self, local_file_path, remote_file_path):
        """Copy ``local_file_path`` to ``remote_file_path`` on the printer over SFTP.

        Returns True on success. Any failure is logged and False is returned.
        """
        ssh_client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys (no MITM protection).
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            ssh_client.connect(hostname=self.host, port=self.SSH_PORT, username=self.SSH_USERNAME,
                               password=self.ssh_password, timeout=self.CONNECT_TIMEOUT)
            sftp = ssh_client.open_sftp()
            try:
                sftp.put(local_file_path, remote_file_path)
            finally:
                sftp.close()
            self.logger.info(f"Successfully uploaded {local_file_path} to {remote_file_path}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to upload the file: {str(e)}")
            return False
        finally:
            ssh_client.close()

    def is_printing(self):
        try:
            return self.connection.printing_flag
        except AttributeError:
            return False

    def is_paused(self):
        try:
            return self.connection.paused_flag
        except AttributeError:
            return False

    def is_operational(self):
        try:
            return self.connection.operational_flag
        except AttributeError:
            return False

    def pause(self):
        """Pause the current job if one is printing; register an error if the request fails."""
        if self.is_printing():
            try:
                success = self.connection.pause()
                if not success:
                    self.register_error(999, "Error pausing job")
                return success
            except AttributeError:
                pass
        return False

    def resume(self):
        """Resume the current job; register an error if the request fails."""
        try:
            success = self.connection.resume()
            if not success:
                self.register_error(998, "Error resuming job")
            return success
        except AttributeError:
            pass
        return False

    def unpause(self):
        return self.resume()

    def cancel(self):
        """Stop the current job; register an error and report failure if the request fails."""
        try:
            success = self.connection.cancel()
        except AttributeError:
            # No connection (e.g. RUNONCE mode).
            return False
        if not success:
            self.register_error(997, "Error canceling job")
        return success

    def get_percent(self):
        try:
            return self.connection.percent
        except AttributeError:
            return 0

    def unbuffered_gcodes(self, gcodes: str) -> bool:
        self.parent.register_error(505, "Sending single line gcodes is not supported", is_blocking=False)
        return False

    def gcodes(self, filepath: str, keep_file: bool = False) -> bool:
        """Upload ``filepath`` to the printer and start it as a new job.

        ``keep_file`` is accepted for interface compatibility and is unused here.
        On failure, a blocking error is registered and False is returned.
        """
        errcode = 0
        errdesc = ""
        # The target is a Linux host, so join with '/' rather than os.path.join
        # (which would use '\' when the client runs on Windows).
        job_path = self.UPLOAD_PATH + "/" + self.UPLOAD_FILENAME
        if self._upload(filepath, job_path):
            try:
                if not self.connection.start_job("Local/" + self.UPLOAD_FILENAME):
                    errcode = 1000
                    errdesc = "Start Job Request failed"
            except Exception as e:
                errcode = 1002
                errdesc = str(e)
        else:
            errcode = 1001
            errdesc = "Upload Failed"
        if errcode:
            self.register_error(errcode, errdesc, is_blocking=True)
        return errcode == 0

    def camera_enable_hook(self, token=None):
        pass

    def camera_disable_hook(self):
        pass

    def set_next_print_options(self, options: dict):
        pass

    def _rounded_temps(self, attr_name):
        """Return the connection's ``attr_name`` list rounded, or a copy of the default list."""
        connection = getattr(self, "connection", None)
        if connection:
            try:
                return self.round_temps_list(getattr(connection, attr_name))
            except Exception as e:
                self.logger.debug(f"Unable to read {attr_name}: {e}")
        # Copy so callers cannot mutate the shared class-level default.
        return list(self.DEFAULT_TEMPS_LIST)

    def get_temps(self):
        return self._rounded_temps("temps")

    def get_target_temps(self):
        return self._rounded_temps("ttemps")

    def close(self):
        """Close the printer connection and unregister its camera URL, if any."""
        # getattr: no connection is created in RUNONCE mode.
        connection = getattr(self, "connection", None)
        if connection:
            connection.close()
            if connection.cam_url:
                self.remove_camera_url(connection.cam_url)

    def check_preconnect_printer_job(self):
        """After a reconnect, flag the cloud job as failed if the printer is running another job.

        This is a one-shot check: ``preconnect_printer_job_id`` is cleared afterwards.
        """
        if self.preconnect_printer_job_id:
            if self.printers_job_id != self.preconnect_printer_job_id:
                self.save_current_printer_job_id(None)
                self.register_error(999,
                                    'After reconnection to the printer, it is not running cloud print job. Assuming that job had failed, when print was offline.',
                                    is_blocking=True)
            self.preconnect_printer_job_id = None

