# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import io
import os
import tempfile
import threading
import time
import zipfile

import requests
import certifi

import config
import log
import paths


class Downloader(threading.Thread):
    """Daemon thread that downloads a print file and hands it over for printing.

    The file is fetched from ``url`` with up to MAX_RETRIES connection
    attempts. An interrupted uncompressed transfer is resumed with an HTTP
    Range request; a transfer with a Content-Encoding is restarted from
    scratch. The result is either a temporary file in paths.DOWNLOAD_FOLDER
    or, when the 'in_memory_gcodes' setting is on and the parent's sender has
    ``print_bytes``, the payload bytes kept in memory.

    When the 'max_parallel_starts' setting is set, the hand-over is bracketed
    by a record in the paths.JOB_STARTS file (one ``id_string`` line per job
    that is currently starting).
    """

    CONNECTION_TIMEOUT = 6  # passed to requests as timeout; also the base of the reconnect delay
    MAX_RETRIES = 5  # connection attempts per download
    DOWNLOAD_CHUNK_SIZE = 128*1024 #128kB
    MAX_JOB_STARTS_CHECK_PERIOD = 6  # seconds to wait between looks at the job starts file

    def __init__(self, parent, url, callback, is_zip):
        """Prepare the download thread; call ``start()`` to run it.

        :param parent: owner object; this class uses its ``logger``,
            ``sender``, ``stop_flag``, ``id_string`` and ``register_error``
        :param url: address to download from
        :param callback: bound method that receives the downloaded file name
            (its ``__self__`` is used as the sender in ``run``)
        :param is_zip: True when the payload is a zip archive
        """
        super().__init__(name="Downloader", daemon=True)
        self.logger = parent.logger.getChild(self.__class__.__name__)
        self.parent = parent
        self.url = url
        self.callback = callback
        self.in_memory_gcodes = False
        #TODO investigate the key error bug in beta 7.42.6
        #get is for temporary fix only
        if config.get_settings().get('in_memory_gcodes'):
            try:
                # Holds the sender's print_bytes attribute itself; only its
                # truthiness is used by this class.
                self.in_memory_gcodes = getattr(parent.sender, 'print_bytes')
            except Exception:
                # Best effort: no sender or no print_bytes means file mode.
                pass
            if self.in_memory_gcodes:
                self.logger.info('In memory gcodes mode enabled')
        self.is_zip = is_zip
        self.cancel_flag = False
        self.download_size = 0
        self.downloaded_bytes = 0
        self.written_bytes = 0
        self.percent = 0.0

    def _wait_unless_stopped(self, slices, slice_duration):
        """Sleep ``slices`` times for ``slice_duration`` seconds, ending early on parent's stop_flag."""
        while not self.parent.stop_flag and slices:
            time.sleep(slice_duration)
            slices -= 1

    def check_wait_or_record_jobs_start(self, retry=900):
        """Wait for a free parallel-start slot and record this job in it.

        Does nothing (and returns None) when the 'max_parallel_starts'
        setting is not set. Otherwise the job starts file is checked every
        MAX_JOB_STARTS_CHECK_PERIOD seconds; as soon as it holds fewer lines
        than the limit, the parent's id_string is appended to it.

        :param retry: maximum number of checks before giving up on waiting
        :return: True when the record was written, the retries ran out or the
            parent is stopping - the caller proceeds in all of these cases
        """
        max_jobs_start_same_time = config.get_settings().get('max_parallel_starts')
        if not max_jobs_start_same_time:
            return None
        # Iterative on purpose: one recursion level per retry would come
        # close to the interpreter recursion limit with the default of 900.
        while True:
            try:
                with open(paths.JOB_STARTS, "r+") as f:
                    jobs_count = f.read().count('\n')
                    if jobs_count < max_jobs_start_same_time:
                        # read() left the position at the end: this appends
                        f.write(self.parent.id_string + "\n")
                        return True
            except OSError as e:
                self.logger.debug(f'Unable to access job starts file: {e}')
            self._wait_unless_stopped(20, self.MAX_JOB_STARTS_CHECK_PERIOD/20)
            retry -= 1
            if retry <= 0 or self.parent.stop_flag:
                return True

    def erase_job_start_record(self, retry=900):
        """Remove this job's line from the job starts file.

        Does nothing (and returns None) when the 'max_parallel_starts'
        setting is not set. Only lines equal to the parent's id_string are
        removed; records of other jobs are kept. Records left behind by a
        process that died before reaching this method are not cleaned up here.

        :param retry: maximum number of attempts when the file can't be accessed
        :return: True when the record was erased, the retries ran out or the
            parent is stopping
        """
        if not config.get_settings().get('max_parallel_starts'):
            return None
        record = self.parent.id_string + "\n"
        while True:
            try:
                # "a+" creates a missing file and never truncates on open, so
                # the existing records can be read before rewriting them.
                with open(paths.JOB_STARTS, "a+") as f:
                    f.seek(0)
                    kept_lines = [line for line in f.readlines() if line != record]
                    f.seek(0)
                    f.truncate()
                    f.writelines(kept_lines)
                return True
            except OSError as e:
                self.logger.debug(f'Unable to update job starts file: {e}')
                self._wait_unless_stopped(1000, 0.1)
            retry -= 1
            if retry <= 0 or self.parent.stop_flag:
                return True

    @log.log_exception
    def run(self):
        """Thread body: download, then pass the result to the sender.

        In memory mode the bytes go to ``parent.sender.print_bytes``;
        otherwise the file (unzipped first, unless the sender declares
        NATIVE_ZIP_HANDLING) goes to the callback. A cancel that arrives
        while this is going on is forwarded to ``parent.sender.cancel``.
        """
        self.logger.info('Starting downloading')
        downloaded = self.download()
        if downloaded and self.callback:
            self.check_wait_or_record_jobs_start()
            try:
                if self.in_memory_gcodes:
                    self.logger.info('In memory gcodes mode enabled. Overriding download complete callback with load_text')
                    self.parent.sender.print_bytes(downloaded)
                else:
                    sender = self.callback.__self__
                    if self.is_zip and not getattr(sender, 'NATIVE_ZIP_HANDLING', False):
                        downloaded = sender.unzip_file(downloaded, not config.get_settings().get('keep_print_files', False))
                    self.execule_callback(downloaded)
            finally:
                # Free the start slot even when the hand-over raised.
                self.erase_job_start_record()
        if self.cancel_flag:
            self.logger.info('Cancel command was received after printing start in downloading thread')
            try:
                self.parent.sender.cancel()
            except AttributeError:
                pass
        self.logger.info('Downloading finished')

    # NOTE: the misspelled name is kept as is, since it is part of the class interface.
    def execule_callback(self, f):
        """Call the callback with ``f`` unless the job was cancelled or ``f`` is empty.

        A pending cancel is consumed (cancel_flag is reset) instead of calling
        the callback; an empty ``f`` is reported as error 67 with job_fail.
        """
        if self.cancel_flag:
            self.logger.info('Cancel command received')
            self.cancel_flag = False
        elif not f:
            self.parent.register_error(67, "Unknown download error", job_fail=True)
        else:
            self.callback(f)

    def download(self):
        """Download the file, retrying and resuming as needed.

        :return: on success the temporary file name or, in memory mode, the
            payload bytes (for a zip: the content of its first member, or
            b"" when the archive is empty); None when the download failed,
            was cancelled or the parent is stopping
        """
        if config.get_settings().get('hide_sensitive_log'):
            self.logger.info("Downloading from __hidden__")
        else:
            self.logger.info("Downloading from " + self.url)
        suffix = ".zip" if self.is_zip else ".gcode"
        if self.in_memory_gcodes:
            tmp_file = io.BytesIO()
            filename = None
        else:
            tmp_file = tempfile.NamedTemporaryFile(mode='wb', dir=paths.DOWNLOAD_FOLDER,
                delete=False, prefix='3dprinteros-', suffix=suffix)
            filename = tmp_file.name
        retry = 0
        compression = None
        while retry < self.MAX_RETRIES and not self.parent.stop_flag and not self.cancel_flag:
            if retry:
                self.logger.warning("Download retry/resume N" + str(retry))
            self.logger.info("Connecting to server...")
            headers = { 'Accept-Encoding': 'identity, deflate, compress, gzip',
                     'Accept': '*/*', 'User-Agent': f'python-requests/{requests.__version__}'}
            if self.downloaded_bytes:
                if compression:
                    self.downloaded_bytes = 0
                    self.percent = 0.0
                    self.written_bytes = 0
                    # truncate() alone keeps the write position, which would
                    # leave a zero-filled gap in front of the restarted data.
                    tmp_file.seek(0)
                    tmp_file.truncate()
                    self.logger.info(f'Unable to resume with compression {compression}. Restarting download')
                else:
                    headers['Range'] = 'bytes=%d-' % self.downloaded_bytes
                    self.logger.info(f'Resuming download from {self.downloaded_bytes}')
            response = None
            try:
                response = requests.get(self.url, headers = headers, stream=True, timeout = self.CONNECTION_TIMEOUT, verify=certifi.where())
            except Exception as e:
                self.parent.register_error(65, "Unable to open download link: " + str(e), job_fail=False)
                # the delay grows with each failed attempt
                time.sleep(self.CONNECTION_TIMEOUT*(retry+1))
            else:
                self.logger.info('Response headers:' + str(response.headers))
                if not response.ok:
                    self.parent.register_error(68, f'Download error: HTTP status not OK, but {response.status_code}', job_fail=False)
                else:
                    if not self.download_size:
                        # Only the first successful response defines the
                        # total size and the compression of the transfer.
                        self.download_size = int(response.headers.get('content-length', 0))
                        self.logger.info(f"Starting download of {self.download_size}B")
                        compression = response.headers.get('Content-Encoding')
                        if compression:
                            self.logger.info("Download compression encoding: " + str(compression))
                    self.downloaded_bytes += self.download_chunks(response, tmp_file)
                    self.logger.info(f"Downloaded {self.downloaded_bytes}B")
                    if self.downloaded_bytes == self.download_size:
                        self.logger.info(f'Success. Downloaded: {self.download_size}B. Wrote: {self.written_bytes}B')
                        return self._finalize_download(tmp_file, filename)
                    elif self.downloaded_bytes > self.download_size:
                        self.parent.register_error(66, f"Download error: data is corrupted. Expected: {self.download_size}B. Downloaded: {self.downloaded_bytes}B. Wrote: {self.written_bytes}B", job_fail=False)
                        break
                    else:
                        self.parent.register_error(66, f"Download error: connection was lost. Expected: {self.download_size}B. Downloaded: {self.downloaded_bytes}B. Wrote: {self.written_bytes}B", job_fail=False)
            finally:
                # A Response is falsy when its status is not OK, so compare
                # with None to close error responses too.
                if response is not None:
                    response.close()
                retry += 1
                time.sleep(1)
        self.parent.register_error(66, 'Download error: unable to complete download', job_fail=True)
        self._discard_download(tmp_file, filename)

    def _finalize_download(self, tmp_file, filename):
        """Close ``tmp_file`` and return the result of a completed download."""
        try:
            if not self.in_memory_gcodes:
                return filename
            tmp_file.seek(0)
            if not self.is_zip:
                return tmp_file.read()
            with zipfile.ZipFile(tmp_file, 'r') as zf:
                namelist = zf.namelist()
                if not namelist:
                    self.parent.register_error(64, "Download error: empty zip file", job_fail=False)
                    return b""
                # only the first member of the archive is used
                with zf.open(namelist[0], 'r') as f:
                    return f.read()
        finally:
            tmp_file.close()

    def _discard_download(self, tmp_file, filename):
        """Best-effort cleanup of a failed download: close and delete the temporary file."""
        try:
            tmp_file.close()
        except Exception as e:
            self.logger.debug(f'Unable to close download file: {e}')
        if filename:
            try:
                os.remove(filename)
            except OSError as e:
                self.logger.debug(f'Unable to remove download file: {e}')

    def download_chunks(self, response, tmp_file):
        """Write the body of ``response`` to ``tmp_file`` chunk by chunk.

        Updates ``percent`` (when the total size is known) and
        ``written_bytes`` along the way. Stops early on cancel, on the
        parent's stop_flag or on any error, which is reported as error 69.

        :param response: streamed requests response to read from
        :param tmp_file: writable binary file object
        :return: value of ``response.raw.tell()`` after the last processed
            chunk, or 0 when no chunk was processed
        """
        downloaded_bytes = 0
        prev_percent = 0
        try:
            for chunk in response.iter_content(self.DOWNLOAD_CHUNK_SIZE):
                if self.cancel_flag or self.parent.stop_flag:
                    self.logger.info('Download canceled')
                    return downloaded_bytes
                downloaded_bytes = response.raw.tell()
                if self.download_size:
                    # self.downloaded_bytes holds what previous attempts got
                    self.percent = round(min((downloaded_bytes + self.downloaded_bytes) / self.download_size, 1.0) * 100, 2)
                    if self.percent > prev_percent:
                        self.logger.info(f'File downloading: {self.percent}%')
                        prev_percent = self.percent
                else:
                    self.logger.info(f"File downloading: {(downloaded_bytes + self.downloaded_bytes) // 1024}kB")
                tmp_file.write(chunk)
                self.written_bytes += len(chunk)
        except Exception as e:
            self.parent.register_error(69, 'Download error: chunk error: ' + str(e), job_fail=False)
            return downloaded_bytes
        else:
            self.percent = 100
            return downloaded_bytes

    def cancel(self):
        """Request cancellation; the flag is checked by the download loops and by ``run``."""
        self.cancel_flag = True

    def get_percent(self):
        """Return the download progress in percent."""
        return self.percent