# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019.
# Built in San Francisco.

# This software is distributed under a commercial license for personal,
# educational, corporate or any other use.
# The software as a whole or any parts of it is prohibited for distribution or
# use without obtaining a license from 3D Control Systems, Inc.

# All software licenses are subject to the 3DPrinterOS terms of use
# (available at https://www.3dprinteros.com/terms-and-conditions/),
# and privacy policy (available at https://www.3dprinteros.com/privacy-policy/)

import os
import sys
import logging
import subprocess
import string
import time

import certifi

import client_ssl_context
import paths
import platforms
from awaitable import Awaitable
# NOTE: Don't import config here: the first import of config must be done in app.App.

# Location where generate_ssh_keys() writes the generated private key; the
# matching public key is created next to it with a ".pub" suffix and is read
# and deleted by generate_ssh_keys().
# NOTE(review): /tmp is world-writable and this file name is predictable -
# confirm this is acceptable on the target devices.
private_key_path = "/tmp/3dprinteros_key"
public_key_path = private_key_path + ".pub"
# authorized_keys file of the user this process runs as.
ssh_authorized_keys_path = os.path.expanduser("~/.ssh/authorized_keys")
# Not referenced in the code visible here; presumably used elsewhere - TODO confirm.
default_shadow_file_path = "/etc/shadow"


# systemd unit queried by is_clisma_running() and toggled by toggle_clisma().
CLISMA_SERVICE_NAME = 'clisma.service'


# When True, add_cacert_to_trust() also anchors the system CA bundle
# before anchoring the custom certificate.
BROKEN_OS_TRUST = False


# Cached result of is_sudo_available(); None means "not probed yet".
sudo_available = None


def is_sudo_available(test_command=None):
    """Return True if non-interactive sudo works for this process.

    The probe runs ``sudo -n <test_command>`` once; the result is cached in the
    module-level ``sudo_available`` and reused by every later call, so
    ``test_command`` only has an effect on the first call.

    :param test_command: command used to probe sudo; defaults to ``whoami``.
    :returns: True when the probe exits with status 0, False otherwise
        (including when sudo cannot be executed at all).
    """
    global sudo_available
    if sudo_available is None:
        if not test_command:
            test_command = 'whoami'
        try:
            # -n: fail instead of prompting for a password.
            # Output is discarded: only the exit status matters.
            result = subprocess.run(['sudo', '-n', test_command],
                                    stdout=subprocess.DEVNULL,
                                    stderr=subprocess.DEVNULL)
        except (OSError, ValueError, subprocess.SubprocessError):
            sudo_available = False
        else:
            sudo_available = result.returncode == 0
    return sudo_available


def is_admin():
    """Return True if the current process has administrator privileges.

    On POSIX this means running as root (uid 0). Where ``os.getuid`` does not
    exist (Windows), ``shell32.IsUserAnAdmin`` is used instead.
    """
    try:
        return os.getuid() == 0
    except AttributeError:
        # os.getuid is not available on Windows.
        import ctypes
        return bool(ctypes.windll.shell32.IsUserAnAdmin())


def check_ssh_dir():
    """Ensure the directory containing ``ssh_authorized_keys_path`` exists.

    Missing parent directories are created as well, and an already existing
    directory is not an error. If the final directory has to be created it
    gets mode 0o700 (subject to the umask), the permissions recommended for
    ``~/.ssh``.
    """
    dirname = os.path.dirname(ssh_authorized_keys_path)
    if dirname:
        os.makedirs(dirname, mode=0o700, exist_ok=True)


def revoke_all_ssh_keys():
    """Empty the authorized_keys file, then restart sshd.

    :returns: True if the file was truncated (sshd is then restarted),
        False if the file could not be created/truncated.
    """
    logger = logging.getLogger(__name__)
    logger.info('Removing ssh authorized keys')
    try:
        check_ssh_dir()
        # Opening in 'w' mode truncates (or creates) the file; nothing to write.
        with open(ssh_authorized_keys_path, 'w'):
            pass
    except Exception as e:
        logger.warning("Error while erasing authorized keys!\n%s", e)
        return False
    restart_sshd()
    return True


def restart_sshd():
    """Restart ``sshd.service`` through systemctl.

    Failures are logged, never raised.

    :returns: True if systemctl exited with status 0, False otherwise
        (including when systemctl could not be started at all).
    """
    logger = logging.getLogger(__name__)
    sshd_restart_command = ['systemctl', 'restart', 'sshd.service']
    logger.info('Restarting sshd...')
    try:
        return_code = subprocess.run(sshd_restart_command).returncode
    except OSError as e:
        # e.g. systemctl is missing on a host without systemd.
        logger.error('sshd restart failed: %s', e)
        return False
    if return_code:
        logger.error('sshd restart failed')
        return False
    logger.info('sshd successfully restarted')
    return True


def add_public_key(public_key):
    """Append ``public_key`` to the authorized_keys file and restart sshd.

    A missing authorized_keys file (or its directory) is created. The key is
    written as its own line:
    - surrounding whitespace is stripped and a newline is added after the key;
    - a newline is inserted first if the existing file does not end with one,
      so the key is never glued onto the previous entry.

    :param public_key: a public key line, e.g. as read from a ``.pub`` file.
    :returns: True if the key was appended. False if the key was empty or
        already present, or if the file could not be read or written.
    """
    logger = logging.getLogger(__name__)
    try:
        key_line = public_key.strip()
        if not key_line:
            logger.warning('Refusing to add an empty public key')
            return False
        check_ssh_dir()
        try:
            with open(ssh_authorized_keys_path, "r") as f:
                ssh_authorized_keys = f.read()
        except FileNotFoundError:
            # No keys authorized yet; the "a" open below creates the file.
            ssh_authorized_keys = ''
        if key_line in ssh_authorized_keys:
            return False
        with open(ssh_authorized_keys_path, "a") as f:
            if ssh_authorized_keys and not ssh_authorized_keys.endswith('\n'):
                f.write('\n')
            f.write(key_line + '\n')
        logger.info('Public key is added to authorized keys')
        restart_sshd()
        return True
    except Exception as e:
        logger.warning('Cant write public key into authorized_keys:\n%s', e)
    return False


def generate_ssh_keys():
    """Generate a fresh ed25519 key pair and authorize its public half.

    Steps:
    1. Remove previous key files at ``public_key_path``/``private_key_path``
       so ssh-keygen does not stop at its overwrite prompt.
    2. Run ssh-keygen.
    3. Read the public key, delete it from disk, and append it to
       authorized_keys via add_public_key().

    :returns: ``private_key_path`` on success. None if key generation failed
        or the public key could not be added to authorized_keys.
    """
    logger = logging.getLogger(__name__)
    for path in (public_key_path, private_key_path):
        # Handle each file independently: a failure on one must not skip the other.
        try:
            if os.path.isfile(path):
                os.remove(path)
        except OSError:
            logger.warning(f'Unable to remove a key file {path}')
    key_generation_line = ("ssh-keygen", "-t", "ed25519", "-N", "", "-f", private_key_path)
    try:
        logger.info('Launching key generation subprocess...')
        # stdin=DEVNULL: if an old key survived the removal above, the
        # overwrite prompt reads EOF instead of waiting on our stdin.
        if subprocess.run(key_generation_line, stdin=subprocess.DEVNULL).returncode:
            logger.error("Can't generate ssh keys. Non zero return code.")
            return None
        logger.info('Keys are generated')
        with open(public_key_path) as f:
            public_key = f.read()
        os.remove(public_key_path)
    except Exception as e:
        logger.error('Error during key generation subprocess\n%s' % e)
        return None
    if not add_public_key(public_key):
        # Without an authorized public key the private key grants no access.
        logger.error('Generated public key could not be added to authorized keys')
        return None
    return private_key_path


def _send_passwd_line(process, line, pause):
    """Write one newline-terminated line to ``process``'s stdin, then pause.

    Flushing before sleeping gives passwd time to consume the line (and
    possibly exit) before the caller inspects or writes to the process again.
    """
    process.stdin.write(line)
    process.stdin.flush()
    time.sleep(pause)


def _run_passwd(old_line, new_line, pause):
    """Feed the current, new and confirmation passwords to /usr/bin/passwd.

    :param old_line: current password as bytes ending in b"\\n".
    :param new_line: new password as bytes ending in b"\\n" (sent twice).
    :param pause: seconds to wait after spawning and after each line.
    :returns: '' when passwd exits with status 0, otherwise an error message.
    :raises OSError, subprocess.TimeoutExpired: when passwd cannot be
        spawned, the pipe breaks, or it does not exit within 5 seconds.
    """
    process = subprocess.Popen(['/usr/bin/passwd'], stdin=subprocess.PIPE)
    try:
        time.sleep(pause)
        _send_passwd_line(process, old_line, pause)
        if process.poll() is not None:
            # passwd already exited right after receiving the current password.
            return 'Password change error: wrong old password'
        _send_passwd_line(process, new_line, pause)
        _send_passwd_line(process, new_line, pause)
        if process.wait(5) == 0:
            return ''
        # Any non-zero status, including negative ones (killed by a signal).
        return 'Password change rejected'
    finally:
        if process.poll() is None:
            # Never leave a passwd child behind (e.g. after wait() timed out).
            process.kill()
            process.wait()
        try:
            process.stdin.close()
        except OSError:
            pass


def change_ssh_password(old_password, new_password, allow_on_non_rpi=False):
    """Change the current user's system password by driving /usr/bin/passwd.

    Both passwords may only contain ``string.printable`` characters other than
    line breaks, because passwd reads each answer as one line of stdin. The
    change only runs on the 'rpi' platform unless ``allow_on_non_rpi`` is
    True. On success sshd is restarted.

    :returns: None on success, otherwise an error message (also logged).
    """
    logger = logging.getLogger(__name__)
    inter_command_sleep = 0.1
    # '\n' or '\r' inside a password would end the line early and shift the
    # current/new/confirm answers that passwd reads.
    allowed_chars = set(string.printable) - set('\r\n')
    forbidden_count = sum(1 for password in (old_password, new_password)
                          for char in password if char not in allowed_chars)
    if forbidden_count:
        # Log only the count: the characters themselves are password material.
        logger.warning('Forbidden characters in password: %d', forbidden_count)
        error = "Password contains forbidden characters"
        logger.warning(error)
        return error
    if platforms.get_platform() != 'rpi' and not allow_on_non_rpi:
        error = "Warning! System password change attempt on non RPi platform!"
        logger.warning(error)
        return error
    try:
        error = _run_passwd(old_password.encode() + b"\n",
                            new_password.encode() + b"\n",
                            inter_command_sleep)
        if not error:
            logger.info('Password successfully changed')
            restart_sshd()
            return None
    except Exception:
        error = "Exception occurred in SSH password change process"
        logger.exception(error)
    logger.warning(error)
    return error


def add_cacert_to_certifi(cert_path=paths.CUSTOM_CACERT_PATH):
    """Append the CA certificate at ``cert_path`` to our certifi bundle file.

    Nothing is done if ``cert_path`` does not exist or its content is already
    in the bundle. After appending, certifi's cached path and the client SSL
    context are reloaded; on linux/rpi the filesystem is synced.

    :returns: None on success or when there is nothing to do; an error message
        string if reading/writing failed or anything else raised.
    """
    if not os.path.exists(cert_path):
        return None
    logger = logging.getLogger('rights')
    try:
        # don't use certifi.where here, since we only want to manipulate our certifi folder
        with open(paths.CERTIFI_CACERT_PATH) as f:
            old_cacerts = f.read()
        with open(cert_path) as f_in:
            new_cacert = f_in.read()
        if new_cacert in old_cacerts:
            return None
        with open(paths.CERTIFI_CACERT_PATH, 'a') as f_out:
            f_out.write('\n')
            f_out.write(new_cacert)
        # The bundle is closed (and therefore flushed) at this point, so the
        # reload and sync below see the appended certificate.
        logger.info('Succesfully added custom ca certificate to certifi')
        # hacky and dangerous for ongoing https requests, but still should work to make certifi to reload new file
        certifi.core._CACERT_PATH = None
        certifi.where()
        client_ssl_context.load_context()
        if platforms.get_platform() in ('linux', 'rpi'):
            os.sync()
    except OSError:
        error = "Error on update of certificates in certifi package"
        logger.exception(error)
        return error
    except Exception:
        error = 'Crash on adding certificate to certifi'
        logger.exception(error)
        return error
    return None


def add_cacert_to_trust():
    """Anchor the custom CA certificate in the system trust store with ``trust``.

    Requires non-interactive sudo. When ``BROKEN_OS_TRUST`` is set, the system
    CA bundle is anchored first. The compat bundles are then re-extracted and
    the filesystem is synced.

    :returns: None on success, otherwise an error message.
    """
    if not is_sudo_available():
        return "Error. No access to sudo"
    commands = []
    if BROKEN_OS_TRUST:
        commands.append(['sudo', '-n', 'trust', 'anchor', '/etc/ca-certificates/extracted/tls-ca-bundle.pem'])
    commands.append(['sudo', '-n', 'trust', 'anchor', paths.CUSTOM_CACERT_PATH])
    commands.append(['sudo', '-n', 'trust', 'extract-compat'])
    commands.append(['sync'])
    try:
        for command in commands:
            # check=True turns a non-zero exit into CalledProcessError; without
            # it failures were silently ignored and the except never fired.
            # sudo -n fails instead of waiting for a password prompt.
            subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError):
        return "Error on update of system certificates"
    return None


def is_clisma_running():
    """Return True if ``systemctl is-active`` reports the CLISMA unit as active.

    A non-zero exit status from systemctl, or failing to launch sudo/systemctl
    at all, is treated as "not running".
    """
    try:
        output = subprocess.check_output(["sudo", "systemctl", "is-active", CLISMA_SERVICE_NAME])
    except (subprocess.CalledProcessError, OSError):
        return False
    return b"inactive" not in output


def toggle_clisma():
    """Stop and disable CLISMA if it is running, otherwise start and enable it.

    Starting is refused when the ``forbid_support_access`` setting is truthy.

    :returns: True if all systemctl commands succeeded. False if starting was
        forbidden by settings, or a systemctl command failed or could not run.
    """
    logger = logging.getLogger('rights')
    try:
        if is_clisma_running():
            logger.info("Stopping and disabling CLISMA...")
            # check=True: a failing systemctl call must not be reported as success.
            subprocess.run(["sudo", "systemctl", "stop", CLISMA_SERVICE_NAME], check=True)
            subprocess.run(["sudo", "systemctl", "disable", CLISMA_SERVICE_NAME], check=True)
            logger.info("CLISMA disabled")
            return True
        import config  # lazy: the first config import must be done in app.App
        if config.get_settings().get('forbid_support_access'):
            logger.info("CLISMA forbidden in settings!")
            return False
        logger.info("Starting and enabling CLISMA...")
        subprocess.run(["sudo", "systemctl", "start", CLISMA_SERVICE_NAME], check=True)
        subprocess.run(["sudo", "systemctl", "enable", CLISMA_SERVICE_NAME], check=True)
        logger.info("CLISMA enabled")
        return True
    except (subprocess.CalledProcessError, OSError):
        logger.exception("Exception while toggling clisma")
        return False


class RightsChecker(Awaitable):
    """Checks, and can fix, the Linux group memberships needed for device access.

    The check passes on non-Linux platforms, when running as admin, or when
    the ``linux_rights_warning`` setting is off. Otherwise the current user
    must belong to all of ``REQUIRED_GROUPS``.
    """

    NAME = 'Linux groups membership'
    REQUIRED_GROUPS = ('tty', 'dialout', 'usbusers')

    # pylint: disable=method-hidden
    def _check_function(self):
        if sys.platform.startswith('linux'):
            import config
            if config.get_settings()['linux_rights_warning'] and not is_admin():
                self.logger.debug('Checking Linux rights')
                # Compare whole group names from `groups` output; a substring
                # test would accept e.g. a group named "ttyusers" as "tty".
                groups = set((self.execute_command('groups') or '').split())
                if not groups.issuperset(self.REQUIRED_GROUPS):
                    self.logger.debug('Current Linux user is not in tty, dialout and usbusers groups')
                    return False
        return True

    def add_user_groups(self):
        """Use pkexec to create ``usbusers`` and add the current user to the required groups.

        Both steps run in one privileged shell so pkexec prompts only once;
        ``groupadd -f`` succeeds when the group already exists.
        """
        if sys.platform.startswith('linux'):
            import getpass
            self.logger.debug('Adding Linux user to necessary groups')
            script = 'groupadd -f usbusers && usermod -a -G dialout,tty,usbusers "$1"'
            # The user name is passed as positional argument $1 rather than
            # interpolated into the script, so the shell never parses it.
            self.execute_command(['pkexec', 'sh', '-c', script, 'sh', getpass.getuser()])

    def execute_command(self, command, shell=False):
        """Run ``command`` and return its stdout as text.

        :param command: argument list, or a string (a program name, or a shell
            command line when ``shell`` is True).
        :param shell: passed through to ``subprocess.run``.
        :returns: stdout if non-empty; None if it was empty or the command
            could not be run (the error is logged at debug level).
        """
        self.logger.debug('Executing command: ' + str(command))
        try:
            stdout = subprocess.run(command, shell=shell, stdout=subprocess.PIPE,
                                    universal_newlines=True).stdout
        except Exception as e:
            self.logger.debug('Error while executing command: "' + str(command) + '\n' + str(e))
            return None
        if stdout:
            self.logger.debug('Execution result: ' + str(stdout))
            return stdout
        return None


def test_change_ssh_password():
    """Manually change the current user's password from the command line.

    Expects exactly two arguments: ``current_password new_password``.
    - Wrong argument count or ``-h``: prints usage and exits with status 1.
    - Change fails: prints the error and exits with status 1.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 2 or "-h" in sys.argv:
        print("Will change current user password.\nUsage: python rights.py current_password new_password")
        sys.exit(1)
    current, replacement = cli_args
    failure = change_ssh_password(current, replacement, allow_on_non_rpi=True)
    if failure:
        print(failure)
        sys.exit(1)


# Running this module directly performs a manual password change, using
# the command-line arguments read by test_change_ssh_password().
if __name__ == "__main__":
    test_change_ssh_password()
