import asyncio
import collections
import concurrent.futures
import io
import json
import logging
import threading
import time
from pprint import pformat

import aiohttp
import aiohttp.client

import base_sender
import log


#Connection and Sender classes are not properly separated, but the connection code turned out
#to be so Repetier Server specific that there is no profit from the separation.
class RepetierServerConnection:
    """Client for the first printer reported by a Repetier Server.

    A background thread (``connect_thread``) owns an asyncio event loop. The loop
    fetches the API key and printer slug over HTTP (``INFO_PATH``), then keeps a
    websocket open. It alternates between reading server messages and sending
    queued commands. Callers interact through a few channels:
    - the deques ``send_now_buffer`` and ``ws_send_deque``;
    - the state attributes this class updates (flags, temperatures, progress);
    - ``upload()``, which schedules a coroutine on the loop thread-safely.
    """

    PORT = 3344
    HOST_MASK = "http://%s:%d"
    WS_HOST_MASK = "ws://%s:%d/socket"
    APIKEY_HEADER_NAME = "x-api-key"
    INFO_PATH = "/printer/info"
    UPLOAD_PATH = "/printer/model/"
    UPLOAD_FILENAME = "3dprinteros" # don't use a dot in the name and be very cautious with length and special characters
    UPLOAD_TIMEOUT = 600 # 10 minutes to upload a file to a repetier server
    MAX_MESSAGE_SIZE = 0x2000 # 8kB(repetier server's max WS message)
    MAX_MESSAGE_INDEX = 0xffff # 65535(16 bit int max)
    TIMEOUT = 2
    LOOP_TIME = 0.1     # please tune this with extreme caution as printer can disconnect or delay messages for minutes
    READS_PER_WRITE = 3 # changing this is risky: after any change test disconnection and message delay!
    RECONNECT_DELAY = 1 # seconds to wait after a failed/dropped connection before reconnecting

    E1_TEMP_ID = 0
    E2_TEMP_ID = 1
    BED_TEMP_ID = 1000

    def __init__(self, host, port, timeout=0, cancel_gcodes=None, printer_interface=None, logger=None):
        """Initialise state and start the background connection thread.

        :param host: hostname or IP address of the Repetier Server.
        :param port: HTTP/websocket port of the server.
        :param timeout: HTTP session timeout in seconds; a falsy value means ``TIMEOUT``.
        :param cancel_gcodes: gcodes queued for immediate sending by ``cancel()``.
        :param printer_interface: optional object whose ``register_error`` receives error reports.
        :param logger: optional parent logger; its ``connection`` child logger is used.
        """
        self.stop_flag = False
        self.host = self.HOST_MASK % (host, port)
        self.ws_host_plus_path = self.WS_HOST_MASK % (host, port)
        if not timeout:
            timeout = self.TIMEOUT
        self.timeout = timeout
        self.cancel_gcodes = [] if cancel_gcodes is None else cancel_gcodes
        self.printer_interface = printer_interface
        self.loop = None
        self.send_now_buffer = collections.deque()  # raw gcodes, sent via the WS 'send' action
        self.ws_send_deque = collections.deque()    # (action, data) pairs for the WS API
        self.print_future = None
        self.apikey = None
        self.slug = None
        self.operational_flag = False
        self.printing_flag = False
        self.paused_flag = False
        self.session = None
        self.message_index = 0
        self.e1_temp = 0.0
        self.e1_ttemp = 0.0
        self.e2_temp = 0.0
        self.e2_ttemp = 0.0
        self.bed_temp = 0.0
        self.bed_ttemp = 0.0
        self.rs_model_id = None  # id of the uploaded model, consumed by the WS loop's copyModel
        self.rs_job_id = None
        self.total_lines = 0
        self.current_line = 0
        self.percent = 0
        if logger:
            self.logger = logger.getChild('connection')
        else:
            self.logger = logging.getLogger(self.__class__.__name__)
        self.connect_thread = threading.Thread(target=self.create_connect_thread)
        self.connect_thread.start()

    def create_connect_thread(self):
        """Thread target: own an event loop and keep (re)connecting until ``stop_flag`` is set."""
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            self.logger.info("Staring event loop...")
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
        while not self.stop_flag:
            self.loop.run_until_complete(self.connect())
            # Without a pause an unreachable server is retried in a tight loop,
            # flooding the log and register_error with connection failures.
            self._wait_before_reconnect()

    def _wait_before_reconnect(self):
        """Sleep up to RECONNECT_DELAY seconds, returning early once stop_flag is set."""
        deadline = time.monotonic() + self.RECONNECT_DELAY
        while not self.stop_flag and time.monotonic() < deadline:
            time.sleep(self.LOOP_TIME)

    def upload(self, f):
        """Upload gcodes to the server and block until the upload finishes.

        Waits at most ``UPLOAD_TIMEOUT`` seconds. It blocks on a future that is
        executed by ``self.loop``, so it must not be called from that loop's thread.

        :param f: file contents/object passed to ``aiohttp.FormData.add_field``.
        :return: True on success; False on error, timeout, cancellation or missing loop.
        """
        if self.loop is None:
            self.logger.warning("Upload error: connection event loop is not running")
            return False
        self.logger.info("Creating upload gcodes task...")
        self.print_future = asyncio.run_coroutine_threadsafe(self.send_file_to_print(f), self.loop)
        self.logger.info("Waiting for upload gcodes to finish...")
        try:
            # run_coroutine_threadsafe returns a concurrent.futures.Future, so result()
            # raises the concurrent.futures exception types, not the asyncio ones.
            upload_error = self.print_future.result(self.UPLOAD_TIMEOUT)
        except concurrent.futures.TimeoutError:
            self.print_future.cancel()
            upload_error = 'Timeout'
        except concurrent.futures.CancelledError:
            upload_error = 'Cancelled'
        except Exception as e:
            self.logger.exception("Exception while uploading gcodes")
            upload_error = 'Exception: %r' % e
        if upload_error:
            self.logger.info("Upload error:" + str(upload_error))
        self.print_future = None
        return not bool(upload_error)

    def get_apikey_headers(self):
        """Return the HTTP/WS headers that authenticate requests with the server API key."""
        return {self.APIKEY_HEADER_NAME: self.apikey}

    def parse_upload_respose(self, response):
        """Store the id of the model named UPLOAD_FILENAME found in an upload response."""
        data = response.get('data')
        if isinstance(data, list):
            for gcode_file in data:
                if isinstance(gcode_file, dict) and gcode_file.get('name') == self.UPLOAD_FILENAME:
                    self.rs_model_id = gcode_file.get('id')

    async def send_file_to_print(self, f):
        """Upload gcodes as UPLOAD_FILENAME into the printer's model storage.

        On success, the model id from the response is stored in ``rs_model_id``.
        The websocket loop then sends ``copyModel`` with autostart for it.

        :return: None on success, otherwise a string describing the error.
        """
        self.percent = 0 # so that 100 from an older job is not treated as the end of this one
        if not self.apikey:
            return 'No API key'
        form = aiohttp.FormData()
        form.add_field('a', 'upload')
        form.add_field('name', self.UPLOAD_FILENAME)
        form.add_field('filename', f, filename=self.UPLOAD_FILENAME, content_type='text/x-gcode')
        form.add_field('overwrite', 'true')
        self.logger.info('Uploading file to RepetierServer')
        timeout = aiohttp.ClientTimeout(total=self.UPLOAD_TIMEOUT)
        async with self.session.post(self.host + self.UPLOAD_PATH + self.slug,
                                     headers=self.get_apikey_headers(),
                                     data=form,
                                     timeout=timeout) as resp:
            if resp.status != 200:
                try:
                    error_body = await resp.text()
                except Exception:
                    error_body = ""
                return "Server responded with error: %s %s" % (str(resp.status), error_body)
            body = ""
            try:
                body = await resp.text()
                json_resp = json.loads(body)
                self.parse_upload_respose(json_resp)
                self.logger.debug('Upload response:' + pformat(json_resp))
            except Exception:
                self.logger.exception("Error parsing upload response:" + str(body))

    async def connect(self):
        """Fetch the API key and printer slug over HTTP, then run the websocket loop.

        On exit it always closes the HTTP session and clears ``operational_flag``.
        Unless stopping, it also reports a non-blocking connection error (code 2000).
        """
        self.logger.info("Connecting to repetier server at %s" % self.host)
        # No explicit loop argument: this coroutine runs inside self.loop already.
        self.session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout))
        try:
            async with self.session.get(self.host + self.INFO_PATH) as resp:
                if resp.status == 200:
                    resp_dict = await resp.json()
                    self.apikey = resp_dict.get('apikey')
                    printers = resp_dict.get('printers', [])
                    if printers:
                        self.slug = printers[0].get('slug')
            if self.apikey and self.slug:
                self.logger.debug('Got apikey and slug:%s %s' % (self.apikey, self.slug))
                self.operational_flag = True
                await self.ws_communication_loop()
        except Exception as e:
            self.logger.error("Repetier server connection to %s failed: %r" % (str(self.host), e))
        finally:
            self.logger.info("Closing connection to RepetierServerConnection at " + str(self.host))
            await self.session.close()
            self.operational_flag = False
            if not self.stop_flag:
                self.register_error(2000, "Connection problems with Repetier Server at %s" % self.host,
                                           is_blocking=False)

    async def ws_communication_loop(self):
        """Read websocket messages and send at most one command per READS_PER_WRITE reads.

        Each outgoing slot alternates between a ``listPrinter`` poll and the first
        available item from ``ws_send_deque``, ``send_now_buffer``, or a pending
        ``copyModel``. It falls back to ``ping`` when none is available.
        """
        self.logger.debug('Entering WS loop...')
        self.logger.info("Connected to RepetierServer at %s" % self.host)
        closing_types = (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED)
        list_printer_request = False
        counter = self.READS_PER_WRITE  # so the first read is immediately followed by a write
        async with self.session.ws_connect(self.ws_host_plus_path,
                                           headers=self.get_apikey_headers()) as ws:
            while not self.stop_flag:
                msg = await ws.receive()
                if self.stop_flag:
                    await ws.close()
                    self.logger.debug('Exiting WS loop due to stop flag...')
                    break
                elif msg.type == aiohttp.WSMsgType.TEXT:
                    self.parse_message(msg.data)
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self.logger.info('Exiting WS loop due to error message...')
                    break
                elif msg.type in closing_types:
                    self.logger.info('Exiting WS loop due to server socket close...')
                    break
                if counter == self.READS_PER_WRITE:
                    counter = 0
                    message = {"action": "ping",
                               "data": {},
                               "printer": self.slug,
                               "callback_id": self.get_message_index()}
                    list_printer_request = not list_printer_request
                    if list_printer_request:
                        message["action"] = "listPrinter"
                    elif self.ws_send_deque:
                        action, data = self.ws_send_deque.popleft()
                        message['action'] = action
                        if data:
                            message['data'] = data
                    elif self.send_now_buffer:
                        data = self.send_now_buffer.popleft()
                        if data:
                            message['action'] = 'send'
                            message['data'] = {"cmd" : data}
                    elif self.rs_model_id is not None:
                        message['action'] = 'copyModel'
                        message['data'] = {"id": self.rs_model_id, "autostart": True, "overwrite": True}
                        self.rs_model_id = None
                    if message['action'] != "ping":
                        self.logger.debug("Sending to WS:" + str(message))
                    await ws.send_json(message)
                counter += 1
                await asyncio.sleep(self.LOOP_TIME)
        self.logger.debug('WS loop finish')

    def parse_message(self, message):
        """Decode one websocket text frame and dispatch its payload entries.

        Nested ``data`` envelopes are unwrapped, and an ``error`` found at any level
        is reported. A ``list`` key is followed, and each resulting dict is passed to
        ``parse_data_field``. Parsing problems are logged and never raised.
        """
        try:
            message = json.loads(message)
        except (TypeError, ValueError):
            self.logger.warning("Received message is not valid JSON:" + str(message))
            return
        self.logger.debug('Parsing WS message:' + pformat(message))
        try:
            while isinstance(message, dict) and 'data' in message:
                if message.get('error'):
                    self.register_error(2002, "Unknown Repetier Server error: %s %s" %
                                              (str(message), str(message.get('error', ""))))
                message = message.get('data')
            if isinstance(message, dict) and 'list' in message:
                message = message['list']
            if isinstance(message, list):
                for data in message:
                    self.parse_data_field(data)
            elif isinstance(message, dict):
                self.parse_data_field(message)
            else:
                self.logger.debug("Strange message:\n" + pformat(message))
        except Exception:
            self.logger.exception('Exception while parsing WS message:')

    def parse_data_field(self, data):
        """Handle one payload entry addressed to our printer (or to no printer in particular)."""
        if not isinstance(data, dict):
            self.logger.debug("Strange data entry:\n" + pformat(data))
            return
        printer_name = data.get('printer') or data.get('slug')
        if printer_name and printer_name != self.slug:
            self.logger.warning("Suspicious printer name received from WS: " + str(data))
            return
        if data.get('error'):
            self.register_error(2002, "Unknown Repetier Server error: %s %s" %
                                      (str(data.get('data', "")), str(data.get('error', ""))))
        self.parse_list_printer(data)
        event = data.get('event')
        if event:
            self.parse_event(data.get('data', {}), event)

    def parse_event(self, data, event):
        """Update temperatures and job flags from a server event payload."""
        if event == 'temp':
            temp = data.get('T', 0.0)
            ttemp = data.get('S', 0.0)
            temp_id = data.get('id')
            if temp_id == self.BED_TEMP_ID:
                self.bed_temp = temp
                self.bed_ttemp = ttemp
            elif temp_id == self.E1_TEMP_ID:
                self.e1_temp = temp
                self.e1_ttemp = ttemp
            elif temp_id == self.E2_TEMP_ID:
                self.e2_temp = temp
                self.e2_ttemp = ttemp
        elif event == 'jobStarted':
            self.printing_flag = True
        elif event == 'jobFinished':
            self.printing_flag = False
            self.paused_flag = False
            self.percent = 100
            self.remove_job()
        elif event == 'jobKilled':
            self.printing_flag = False
            self.paused_flag = False
            self.remove_job()
        elif event == 'changeFilamentRequested':
            self.register_error(2000, "Filament change requested", is_blocking=False, is_info=True)
            self.paused_flag = True

    def parse_list_printer(self, data):
        """Update job state and progress from a listPrinter-style entry containing 'job'."""
        job = data.get("job")
        if job:
            if job == 'none':
                self.printing_flag = False
                self.paused_flag = False
                self.total_lines = 0
                self.current_line = 0
            else:
                self.printing_flag = True
                self.paused_flag = bool(data.get("paused"))
                try:
                    if job == self.UPLOAD_FILENAME:
                        self.rs_job_id = data.get('jobid', None)
                    self.total_lines = int(data.get('totalLines', 0))
                    self.current_line = int(data.get('linesSend', 0))
                    self.percent = int(data.get('done', 0))
                except (TypeError, ValueError):
                    self.logger.warning("Can't parse message:" + str(data))

    def get_message_index(self):
        """Return the next callback id, cycling through 1..MAX_MESSAGE_INDEX."""
        if self.message_index >= self.MAX_MESSAGE_INDEX:
            self.message_index = 0
        self.message_index += 1
        return self.message_index

    def remove_job(self, job_id=None):
        """Queue a removeJob command for ``job_id``, defaulting to the last seen job id."""
        if job_id is None:
            job_id = self.rs_job_id
        if job_id is not None:
            self.send_to_ws("removeJob", {"id": job_id})
        else:
            self.logger.warning("Unable to remove job without id")

    def send_to_ws(self, command, data=None):
        """Queue a websocket API action; ``data`` defaults to a fresh empty dict."""
        self.ws_send_deque.append((command, {} if data is None else data))

    def pause(self):
        """Queue the '@pause' command for immediate sending."""
        return self.send_now_buffer.append("@pause")

    def unpause(self):
        """Queue a continueJob action."""
        return self.send_to_ws("continueJob")

    def cancel(self):
        """Queue stopJob and removeJob, then the configured cancel gcodes."""
        self.send_to_ws("stopJob")
        self.remove_job()
        self.send_now_buffer.extend(self.cancel_gcodes)

    def register_error(self, code, message, is_blocking=False, is_info=False):
        """Forward an error report to ``printer_interface`` when one was provided."""
        if self.printer_interface:
            self.printer_interface.register_error(code, message, is_blocking, is_info=is_info)

    def close(self):
        """Stop the connection, cancel a pending upload, join the thread and close the loop.

        The join waits at most TIMEOUT + 0.1 seconds. If the loop is still running
        after that, ``loop.close()`` raises RuntimeError, which is ignored.
        """
        self.logger.info("Closing " + __class__.__name__)
        self.stop_flag = True
        if self.print_future:
            self.print_future.cancel()
        self.connect_thread.join(self.TIMEOUT + 0.1)
        if self.loop:
            try:
                self.loop.close()
            except RuntimeError:
                # the connection thread did not finish in time; its loop is still running
                pass


class Sender(base_sender.BaseSender):
    """Sender that drives a printer through a Repetier Server.

    All printer I/O is delegated to a ``CONNECTION_CLASS`` instance. The methods
    here map the BaseSender interface onto it and read back the state it keeps.
    """

    CONNECTION_CLASS = RepetierServerConnection
    OK_TIMEOUT = 2

    def __init__(self, parent, usb_info, profile):
        super().__init__(parent, usb_info, profile)
        self.connect()

    def connect(self):
        """Create the connection and check that it became operational.

        :raises RuntimeError: if the connection is not operational after OK_TIMEOUT seconds.
        """
        self.connection = self.CONNECTION_CLASS(self.usb_info['IP'],
                                                self.profile.get('port', self.CONNECTION_CLASS.PORT),
                                                self.profile.get('timeout'),
                                                self.preprocess_gcodes(self.profile.get('end_gcodes', "")),
                                                self.parent)
        # give the connection thread time to fetch the API key and printer slug
        time.sleep(self.OK_TIMEOUT)
        if not self.connection.operational_flag:
            self.connection.close()
            self.connection = None
            raise RuntimeError("No connection to host:%s" % self.usb_info['IP'])

    def gcodes(self, filename, keep_file=False):
        """Upload the gcode file at ``filename`` so the server can start the job.

        :param filename: path of the gcode file. It is opened read-only, because a
            write mode would truncate the file before it is uploaded.
        :param keep_file: accepted for interface compatibility; not used here.
        :return: True if the upload succeeded, False otherwise.
        """
        if not self.connection:
            self.logger.info('Error: no connection')
            return False
        with open(filename, 'rb') as f:
            self.logger.info("Uploading gcodes on RepetierServer")
            if self.connection.upload(f):
                self.logger.info("Print job started")
                return True
        self.logger.warning("Uploading gcodes on RepetierServer failed")
        return False

    def unbuffered_gcodes(self, gcodes):
        """Queue newline-separated gcodes (str or UTF-8 bytes) for immediate sending.

        :return: True when queued, False when there is no connection.
        """
        self.logger.info("Gcodes to send now: " + str(gcodes))
        if not self.connection:
            return False
        if isinstance(gcodes, bytes):
            gcodes = gcodes.decode('utf-8')
        for gcode in gcodes.split("\n"):
            if not gcode.strip():
                continue  # e.g. the empty piece after a trailing newline
            self.connection.send_now_buffer.append(gcode + "\n")
        return True

    def is_printing(self):
        return bool(self.connection and self.connection.printing_flag)

    def is_paused(self):
        return bool(self.connection and self.connection.paused_flag)

    def is_operational(self):
        return bool(self.connection and self.connection.operational_flag)

    def pause(self):
        """Request a pause when printing; return True if the request was queued."""
        if self.is_printing():
            self.connection.pause()
            return True
        return False

    def unpause(self):
        """Request continuing a paused job; return True if the request was queued."""
        if self.is_paused():
            self.connection.unpause()
            return True
        return False

    def cancel(self):
        """Request cancelling the current job; return True if the request was queued."""
        if self.is_printing():
            self.connection.cancel()
            return True
        return False

    def get_percent(self):
        # connection is None when not connected -> AttributeError -> 0
        try:
            return self.connection.percent
        except AttributeError:
            return 0

    def get_current_line_number(self):
        try:
            return self.connection.current_line
        except AttributeError:
            return 0

    def get_total_gcodes(self):
        try:
            total_gcodes = self.connection.total_lines
        except AttributeError:
            total_gcodes = 0
        return total_gcodes

    def get_temps(self):
        """Return current temperatures as [bed, extruder 1, extruder 2]."""
        if self.connection:
            temps = self.round_temps_list([self.connection.bed_temp,
                                           self.connection.e1_temp,
                                           self.connection.e2_temp])
        else:
            temps = [0.0, 0.0, 0.0]
        return temps

    def get_ttemps(self):
        """Return target temperatures as [bed, extruder 1, extruder 2]."""
        if self.connection:
            temps = self.round_temps_list([self.connection.bed_ttemp,
                                           self.connection.e1_ttemp,
                                           self.connection.e2_ttemp])
        else:
            temps = [0.0, 0.0, 0.0]
        return temps

    def close(self):
        if self.connection:
            self.connection.close()


if __name__ == "__main__":

    def state(self):
        """Return a short state string for a Sender (manual test helper)."""
        if self.stop_flag:
            state = "connecting"
        elif self.is_paused():
            state = "paused"
        elif self.is_printing():
            state = "printing"
        elif self.is_operational():
            state = "ready"
        else:
            state = "error"
        return state

    logging.basicConfig(format='', level=logging.DEBUG)
    gcode_path = "/tmp/1.gcode"
    sender = None
    homing = True # home at start
    printing = True # send gcodes to print
    pause = True # pause in ~30 seconds and continue in ~60
    cancel = False # cancel in ~80 seconds
    exit_test = False # exit in ~120 seconds
    pause_on_counter = 15
    continue_on_counter = 30
    cancel_on_counter = 40
    exit_on_counter = 60
    counter = 0
    try:
        sender = Sender(None, {"IP": "192.168.1.218"}, {"port": 3344})
        while True:
            if homing:
                sender.unbuffered_gcodes("G28")
                homing = False
            elif printing:
                # Sender.gcodes takes a path and opens the file itself
                sender.gcodes(gcode_path)
                printing = False
            print("State:", state(sender), sender.get_percent())
            time.sleep(2)
            counter += 1
            if pause and counter == pause_on_counter:
                sender.pause()
            elif pause and counter == continue_on_counter:
                sender.unpause()
            elif cancel and counter == cancel_on_counter:
                sender.cancel()
            elif exit_test and counter == exit_on_counter:
                sender.close()
                time.sleep(5)
                sender = None
                break
    except KeyboardInterrupt:
        if sender:
            sender.close()
