# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import os
import logging
import time
import json
import zipfile
import http.client
import requests
import threading

import version
import config
import log
import platforms
import paths
import client_ssl_context


class Updater:
    """Poll the update server for new client versions and download them.

    The update package is saved to ``paths.UPDATE_FILE_PATH``. Installing it is
    left to a restart of the client, which this class requests by setting
    ``parent.stop_flag``.

    ``parent`` must provide ``stop_flag`` and
    ``register_error(code, message, is_blocking=...)``; its
    ``printer_interfaces`` attribute is consulted when present.
    """

    DEFAULT_UPDATE_CHECK_PAUSE = 300 # five minutes
    # Seconds with no self-printing printer in a delaying state before auto-update restarts.
    IDLE_TIMEOUT_FOR_AUTOUPDATE = 15
    # checks_loop sleeps LOOP_TIME seconds per cycle, split into LOOP_SLEEP_STEPS
    # slices so that stop_flag is noticed quickly.
    LOOP_SLEEP_STEPS = 20
    # Timeout (seconds) for update-server requests and download connections.
    UPDATE_URL_REQUEST_TIMEOUT = 6
    LOOP_TIME = 2
    # A server reply containing this character is treated as markup
    # (presumably an HTML error page) rather than a download URL.
    OPENING_TAG = "<"
    DOWNLOAD_RETRIES = 5
    STATES_THAT_DELAY_AUTOUPDATE = ('printing', 'paused', 'local_mode')
    DOWNLOAD_CHUNK_SIZE = 64 * 1024

    def __init__(self, parent):
        self.parent = parent
        self.logger = logging.getLogger(__name__)
        self.update_download_url = None
        self.update_available = False
        self.update_downloaded = False
        self.check_time = 0
        self.auto_update_start_thread = None
        self.loop_thread = None
        # Initialised here so readers never hit AttributeError before a download starts.
        self.download_percent = 0
        settings = config.get_settings()
        if settings['protocol']['encryption']:
            self.port = 443
            self.connection_class = http.client.HTTPSConnection
            self.connection_kwargs = {'context': client_ssl_context.SSL_CONTEXT}
        else:
            self.connection_class = http.client.HTTPConnection
            self.connection_kwargs = {}
            self.port = 80
        # Fields merged over the default update-request payload
        # (e.g. to request a different version, branch or platform).
        self.forced_update_request_args = {}
        custom_port = settings['update'].get('custom_port')
        if custom_port:
            self.port = int(custom_port)
        if settings['update'].get('no_url_prefix', False):
            self.url = settings['URL']
        else:
            try:
                # "a.b.c" -> "a-update.b.c"; only hosts with exactly three
                # dot-separated parts fit this pattern.
                self.url = "%s-update.%s.%s" % tuple(settings['URL'].split('.'))
            except (TypeError, AttributeError):
                self.logger.info("Strange update URL detected. Falling back to base URL")
                self.url = settings['URL']

    def check_update_timer(self):
        """Run check_for_updates() if the configured check pause has elapsed.

        The timer is only reset after a check without errors, so failed checks
        are retried on the next loop iteration.
        """
        current_time = time.monotonic()
        check_pause = config.get_settings().get('update', {}).get('check_pause', self.DEFAULT_UPDATE_CHECK_PAUSE)
        if current_time - self.check_time > check_pause:
            if self.check_for_updates():
                self.check_time = current_time

    def checks_loop(self):
        """Periodically check for updates until parent.stop_flag is set."""
        while not self.parent.stop_flag:
            self.check_update_timer()
            for _ in range(self.LOOP_SLEEP_STEPS):
                time.sleep(self.LOOP_TIME / self.LOOP_SLEEP_STEPS)
                if self.parent.stop_flag:
                    break

    def start_checks_loop(self):
        """Start checks_loop in a background thread if updates are enabled."""
        if config.get_settings()['update']['enabled']:
            self.loop_thread = threading.Thread(target=self.checks_loop)
            self.loop_thread.start()

    def check_for_updates(self, force_update_and_restart = False):
        """Ask the update server for a new version and start auto-update if allowed.

        Args:
            force_update_and_restart: start the update thread even when
                auto-update is disabled in config, and restart without
                waiting for prints to end.

        Returns:
            True if the server request succeeded (whether or not an update
            exists), False on request errors, None when updates are disabled.
            The return value is not the availability of an update.
        """
        if not config.get_settings()['update']['enabled']:
            # NOTE(review): presumably this sleep keeps polling callers from
            # spinning while updates are disabled - confirm.
            time.sleep(6)
            return None
        self.update_download_url = self.request_update_download_url()
        if self.update_download_url:
            self.logger.info('Update available!')
            self.update_available = True
            if config.get_settings()['update']['auto_update_enabled'] or force_update_and_restart:
                thread_running = self.auto_update_start_thread and self.auto_update_start_thread.is_alive()
                if not thread_running and not self.parent.stop_flag:
                    self.auto_update_start_thread = threading.Thread(target=self.auto_update,
                                                                     args=(not force_update_and_restart,))
                    self.auto_update_start_thread.start()
            return True
        # An empty string is a successful reply without an update URL; None is an error.
        return self.update_download_url == ""

    def request_update_download_url(self):
        """POST this client's version info to the update server.

        Returns:
            The download URL string if the server offers one, "" if the server
            replied 200 with an empty body, or None on connection errors,
            non-200 statuses and replies that look like markup.
        """
        payload_dict = {'platform': platforms.get_platform(),
                        'client_version': version.version,
                        'branch': version.branch,
                        'build': version.build}
        payload_dict.update(self.forced_update_request_args)
        payload = json.dumps(payload_dict)
        headers = {"Content-Type": "application/json",
                   "Content-Length": str(len(payload.encode('utf-8')))}
        connection = None
        try:
            connection = self.connection_class(self.url, port=self.port,
                                               timeout=self.UPDATE_URL_REQUEST_TIMEOUT,
                                               **self.connection_kwargs)
            connection.connect()
            connection.request('POST', '/noauth/get_client_update_url', payload, headers)
            resp = connection.getresponse()
            if resp.status == 200:
                update_download_url = resp.read().decode("utf-8")
            else:
                update_download_url = None
            self.logger.debug(f'Requested update from {self.url}. With {payload}. Resp: {update_download_url}')
        except Exception as e:
            self.logger.warning('Unable to connect to updater server. ' + str(e))
            return None
        finally:
            # Close on every path; previously the connection leaked when
            # connect/request/getresponse raised.
            if connection is not None:
                connection.close()
        if update_download_url:
            if self.OPENING_TAG not in update_download_url:
                self.logger.info('Update download URL received: ' + str(update_download_url))
                return update_download_url
            self.logger.warning('Received invalid update URL: ' + str(update_download_url))
            return None
        if update_download_url == "":
            return update_download_url
        return None

    @log.log_exception
    def auto_update(self, wait_for_prints_end=True):
        """Download the update, then request a restart to install it.

        Args:
            wait_for_prints_end: if True, restart only after printers become
                idle; otherwise restart right after the download.
        """
        self.logger.info('Starting automatic update')
        error = self.download_update()
        if not error:
            if wait_for_prints_end:
                self.wait_for_prints_end()
            else:
                self.logger.info('Restarting 3DPrinterOS to install the update')
                self.parent.stop_flag = True

    def wait_for_prints_end(self):
        """Block until no checked printer is busy for IDLE_TIMEOUT_FOR_AUTOUPDATE s, then request restart."""
        self.logger.info('Waiting for prints to end to install update...')
        last_non_idle_time = time.monotonic()
        while not self.parent.stop_flag:
            # NOTE(review): only printers whose profile sets 'self_printing' are
            # checked; confirm that printers without it cannot be interrupted
            # by the restart.
            for pi in getattr(self.parent, "printer_interfaces", []):
                if pi.printer_profile and pi.printer_profile.get('self_printing', False):
                    if pi.get_printer_state() in self.STATES_THAT_DELAY_AUTOUPDATE:
                        last_non_idle_time = time.monotonic()
                        break
            else:
                # Reached only when no printer triggered the break above.
                if time.monotonic() - last_non_idle_time > self.IDLE_TIMEOUT_FOR_AUTOUPDATE:
                    self.logger.info('All prints are finished. Restarting 3DPrinterOS to install the update')
                    self.parent.stop_flag = True
            time.sleep(1)

    def download_update(self):
        """Download and verify the update package.

        Returns:
            None on success, otherwise a message string. A message is also
            returned when a valid package is already present on disk.
        """
        if not self.update_download_url:
            self.update_download_url = self.request_update_download_url()
        if self.update_download_url:
            if os.path.exists(paths.UPDATE_FILE_PATH) and zipfile.is_zipfile(paths.UPDATE_FILE_PATH):
                self.logger.info('Update package already downloaded')
                return "Update will be applied after restart"
            self.logger.info('Downloading update package...')
            error_message = self.chunked_download(self.update_download_url)
            if error_message:
                return error_message
            return self.check_downloaded_update_zip()
        return "No update available"

    def chunked_download(self, url):
        """Download url to paths.UPDATE_FILE_PATH, resuming after failures.

        Up to DOWNLOAD_RETRIES attempts are made; later attempts ask the
        server to resume via an HTTP Range header. Returns None on success or
        an error message; a partial file is removed on failure.
        """
        self.logger.info(f"Downloading from {url}")
        try:
            output_file = open(paths.UPDATE_FILE_PATH, "wb")
        except OSError:
            message = 'Error opening file to save an update package'
            self.logger.warning(message)
            return message
        self.download_percent = 0
        resume_byte_pos = 0
        retry = 0
        # 'with' guarantees the file is closed even if an unexpected error escapes.
        with output_file:
            while retry < self.DOWNLOAD_RETRIES and not self.parent.stop_flag:
                if retry:
                    self.logger.warning("Download retry/resume N%d" % retry)
                self.logger.info("Connecting to " + url)
                resume_header = {'Range': 'bytes=%d-' % resume_byte_pos}
                request = None
                try:
                    request = requests.get(url,
                                           headers=resume_header,
                                           stream=True,
                                           timeout=self.UPDATE_URL_REQUEST_TIMEOUT)
                except Exception as e:
                    self.parent.register_error(65, "Unable to open download link: " + str(e), is_blocking=False)
                else:
                    self.logger.info("Successful connection to " + url)
                    resume_byte_pos, completed = self._store_response(request, output_file, resume_byte_pos)
                    if completed:
                        return None
                finally:
                    # requests.Response is falsy for 4xx/5xx statuses, so test
                    # against None to make sure every response gets closed.
                    if request is not None:
                        request.close()
                retry += 1
                time.sleep(1)
        # Remove the partially downloaded file so it is not mistaken for a package.
        try:
            os.remove(paths.UPDATE_FILE_PATH)
        except OSError:
            pass
        return 'Update package download canceled'

    def _store_response(self, response, output_file, resume_byte_pos):
        """Write one download response body to output_file.

        Returns:
            (resume_byte_pos, completed): the number of bytes now stored in
            output_file and whether this response's body was fully received.
        """
        status = response.status_code
        if status not in (200, 206):
            # Don't save an error page into the package file.
            self.parent.register_error(65, "Unable to open download link: HTTP status %d" % status, is_blocking=False)
            return resume_byte_pos, False
        if status == 200 and resume_byte_pos:
            # The server ignored the Range header and is resending the whole
            # file: discard the partial data instead of appending a duplicate.
            self.logger.warning("Download server ignored resume request. Restarting download from the beginning")
            output_file.seek(0)
            output_file.truncate()
            resume_byte_pos = 0
        download_length = int(response.headers.get('content-length', 0))
        if not download_length:
            # Without Content-Length completeness cannot be verified; retry.
            return resume_byte_pos, False
        downloaded_size = self.download_chunks(response, download_length, output_file)
        resume_byte_pos += downloaded_size
        self.logger.info("Downloaded %d bytes" % resume_byte_pos)
        return resume_byte_pos, downloaded_size == download_length

    def download_chunks(self, response, download_length, tmp_file):
        """Stream response into tmp_file and return the number of bytes received.

        Stops early (returning the count so far) when parent.stop_flag is set
        or a streaming error occurs; errors are reported via register_error.
        """
        downloaded_bytes = 0
        try:
            for chunk in response.iter_content(self.DOWNLOAD_CHUNK_SIZE):
                if self.parent.stop_flag:
                    self.logger.info('Download canceled')
                    return downloaded_bytes
                tmp_file.write(chunk)
                # Update the count only after a successful write, so a failed
                # write is re-requested on resume instead of leaving a gap.
                downloaded_bytes = response.raw.tell()
                percent = round(min(downloaded_bytes / download_length, 1.0) * 100, 2)
                self.logger.info(f'File downloading: {downloaded_bytes // 1024}kB {percent}%')
        except Exception as e:
            self.parent.register_error(69, 'Download error: chunk error: ' + str(e), is_blocking=False)
        return downloaded_bytes

    def check_downloaded_update_zip(self):
        """Verify the downloaded package is a zip file.

        Returns None on success; otherwise removes the file, forgets the
        download URL and returns an error message.
        """
        if zipfile.is_zipfile(paths.UPDATE_FILE_PATH):
            self.logger.info('...update successfully downloaded!')
            self.update_available = False
            self.update_downloaded = True
            return None
        self.update_download_url = None
        error_message = 'Error: corrupted update package.'
        self.logger.warning(error_message)
        try:
            os.remove(paths.UPDATE_FILE_PATH)
        except OSError:
            self.logger.warning("Unable to remove corrupted update package.")
        else:
            self.logger.warning("Corrupted update package removed.")
        return error_message

    def update(self):
        """Check for an update and, if one exists, install it with an immediate restart."""
        self.check_for_updates(force_update_and_restart=True)


if __name__ == "__main__":

    logging.basicConfig()
    logging.getLogger().setLevel(logging.INFO)

    class FakeApp:
        """Minimal stand-in for the client application object Updater expects."""

        def __init__(self):
            self.stop_flag = False

        def register_error(self, code, message, is_blocking=False):
            # Updater reports download failures through this method; without it
            # any download error raised AttributeError. Just log them here.
            logging.getLogger(__name__).warning("Error %s: %s", code, message)

    import argparse

    parser = argparse.ArgumentParser(description='Check for 3DPrinterOS Client updates, download and apply them if available. Can be used to switch software branch.')
    parser.add_argument('--version', '-v', default=None)
    parser.add_argument('--branch', '-b', default=None)
    parser.add_argument('--platform', '-p', default=None)
    parser.add_argument('--force', '-f', action='store_true')
    parser.add_argument('--download', '-d', action='store_true')
    args = parser.parse_args()
    fa = FakeApp()
    u = Updater(fa)
    if args.force:
        # Report a very old version so the server always offers an update.
        u.forced_update_request_args['client_version'] = '0.0.0'
        if os.path.exists(paths.UPDATE_FILE_PATH):
            os.remove(paths.UPDATE_FILE_PATH)
    elif args.version:
        u.forced_update_request_args['client_version'] = args.version
    if args.branch:
        u.forced_update_request_args['branch'] = args.branch
    if args.platform:
        u.forced_update_request_args['platform'] = args.platform
    # Query the server directly instead of via check_for_updates(). That method
    # does nothing when updates are disabled in config, and with auto-update
    # enabled it starts a background thread that downloads the same file
    # concurrently with --download below.
    u.update_download_url = u.request_update_download_url()
    if u.update_download_url:
        if args.download:
            print(f"Downloading update package from {u.update_download_url} to " + paths.UPDATE_FILE_PATH)
            download_error = u.download_update()
            if download_error:
                print("Update download error:", download_error)
            else:
                print("Update download success. Restart 3DPrinterOS Client to update")
        else:
            print('Update package available at: ' + u.update_download_url)
    else:
        print('No update available')
