#
# Copyright 3D Control Systems, Inc. All Rights Reserved 2017-2019. Built in San Francisco.
#
# This software is distributed under commercial non-GPL license for personal, educational,
# corporate or any other use. The software as a whole or any parts of that are prohibited
# for distribution and/or use without obtaining license from 3D Control Systems, Inc.
#
# If you do not have the license to use this software, please delete all software files
# immediately and contact sales to obtain the license: sales@3dprinteros.com.
# If you are unsure about the licensing please contact directly our sales: sales@3dprinteros.com.

import os
import sys
import logging
from subprocess import Popen, PIPE, call, check_call

import config
import platforms
from checker_waiter import CheckerWaiter

# Location where generate_ssh_keys() asks ssh-keygen to write the new key pair;
# the public half is read from "<private_key_path>.pub".
private_key_path = "/tmp/3dprinteros_key"
public_key_path = "%s.pub" % private_key_path
# authorized_keys file of the root user, managed by add_public_key()/revoke_all_ssh_keys().
ssh_authorized_keys_path = "/root/.ssh/authorized_keys"
# Shadow password file patched by change_ssh_password().
default_shadow_file_path = "/etc/shadow"


def is_admin():
    """Return True if the current process runs with administrator privileges.

    On POSIX systems this checks whether the real user id is 0 (root).
    Where ``os.getuid`` is unavailable (Windows) it falls back to the
    shell32 ``IsUserAnAdmin`` call, whose integer result is converted to bool.
    """
    try:
        return os.getuid() == 0
    except AttributeError:
        # os.getuid only exists on Unix; only this case should trigger the
        # Windows fallback (a bare except would also swallow KeyboardInterrupt).
        import ctypes
        return bool(ctypes.windll.shell32.IsUserAnAdmin())


def check_ssh_dir():
    """Make sure the directory holding ``ssh_authorized_keys_path`` exists.

    Only the last path component is created (the parent must already exist).
    A newly created directory gets mode 0700 (further limited by the umask),
    the conventional private permission for an ``.ssh`` directory. An
    already existing path is left untouched.

    :raises OSError: if the directory cannot be created for any reason other
        than already existing.
    """
    import errno
    dirname = os.path.dirname(ssh_authorized_keys_path)
    try:
        # Create directly instead of exists() + mkdir(): a directory created
        # between the check and the mkdir is then not reported as an error.
        os.mkdir(dirname, 0o700)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def revoke_all_ssh_keys():
    """Empty the authorized_keys file and restart sshd.

    :return: True once the file was emptied and sshd restarted; None if the
        file could not be emptied (the error is logged as a warning).
    """
    log = logging.getLogger(__name__)
    log.info('Removing ssh authorized keys')
    try:
        check_ssh_dir()
        # Opening in 'w' mode truncates the file to zero length.
        with open(ssh_authorized_keys_path, 'w'):
            pass
    except Exception as error:
        log.warning("Error while erasing authorized keys!\n%s" % error)
        return None
    restart_sshd()
    return True


def restart_sshd():
    """Restart the sshd service through systemctl and log the outcome.

    A non-zero exit status is only logged as an error; nothing is returned.
    """
    log = logging.getLogger(__name__)
    restart_command = 'systemctl restart sshd.service'
    log.info('Restarting sshd...')
    exit_status = call(restart_command, shell=True)
    if exit_status != 0:
        log.error('sshd restart failed')
    else:
        log.info('sshd successfully restarted')


def add_public_key(public_key):
    """Append ``public_key`` to the authorized_keys file unless already present.

    The authorized_keys file is created if it does not exist yet. Keys are
    written one per line; sshd is restarted after a key has been added.

    :param public_key: OpenSSH public key text (e.g. the contents of a .pub
        file); surrounding whitespace is ignored.
    :return: True if the key is in authorized_keys afterwards; False if the key
        is empty or the file could not be read/written (a warning is logged).
    """
    import errno
    logger = logging.getLogger(__name__)
    try:
        key_line = public_key.strip()
        if not key_line:
            raise ValueError('public key is empty')
        check_ssh_dir()
        try:
            with open(ssh_authorized_keys_path, "r") as f:
                ssh_authorized_keys = f.read()
        except (IOError, OSError) as e:
            # A freshly created .ssh directory has no authorized_keys file yet;
            # treat it as empty instead of failing to add the key.
            if e.errno != errno.ENOENT:
                raise
            ssh_authorized_keys = ""
        if key_line not in ssh_authorized_keys:
            with open(ssh_authorized_keys_path, "a") as f:
                # Without this, a last line lacking "\n" would be merged with
                # the new key into one invalid entry.
                if ssh_authorized_keys and not ssh_authorized_keys.endswith("\n"):
                    f.write("\n")
                f.write(key_line + "\n")
            logger.info('Public key is added to authorized keys')
            restart_sshd()
    except Exception as e:
        logger.warning('Cant write public key into authorized_keys:\n%s' % e)
        return False
    return True


def generate_ssh_keys():
    """Generate a new passphrase-less RSA key pair and authorize its public key.

    Old files at ``public_key_path`` / ``private_key_path`` are removed first,
    presumably so ssh-keygen does not prompt about overwriting them. The public
    key is read, its file deleted, and the key passed to add_public_key(); only
    the private key file is left on disk.

    :return: ``private_key_path`` if key generation succeeded, otherwise None
        (the error is logged as a warning).
    """
    logger = logging.getLogger(__name__)
    for stale_path in (public_key_path, private_key_path):
        try:
            os.remove(stale_path)
        except OSError:
            # Typically the file simply does not exist yet.
            pass
    try:
        logger.info('Launching key generation subprocess...')
        # Pass an argv list so no shell parsing is involved; '' is the empty passphrase.
        check_call(['ssh-keygen', '-t', 'rsa', '-N', '', '-f', private_key_path])
        logger.info('Keys are generated')
        with open(public_key_path) as f:
            public_key = f.read()
        os.remove(public_key_path)
    except Exception as e:
        logger.warning('Error during key generation subprocess\n%s' % e)
    else:
        add_public_key(public_key)
        return private_key_path


def change_ssh_password(password):
    """Change the system password in the shadow file and restart sshd (RPi only).

    :param password: new password, passed to ``patch_shadow.change_password``.
    :return: True on success; None on non-'rpi' platforms (a warning is logged)
        or when changing the password / restarting sshd failed (the error is
        logged; the password itself is never logged).
    """
    logger = logging.getLogger(__name__)
    if platforms.get_platform() != 'rpi':
        logger.warning("Warning! System password change attempt on non RPi platform!")
        return None
    from patch_shadow import change_password
    try:
        change_password(default_shadow_file_path, password)
        restart_sshd()
    except Exception as e:
        # Previously swallowed silently, which made failed password changes invisible.
        logger.warning("Error while changing system password:\n%s" % e)
        return None
    return True


class RightsCheckerWaiter(CheckerWaiter):
    """CheckerWaiter that checks the current Linux user's membership in the
    groups listed in REQUIRED_GROUPS and can request membership via gksu."""

    NAME_FOR_LOGGING = 'Linux groups membership'
    # Groups the user must belong to for check_if_user_in_groups to pass.
    REQUIRED_GROUPS = ('tty', 'dialout', 'usbusers')

    def __init__(self, app):
        CheckerWaiter.__init__(self, app, self.check_if_user_in_groups)

    def check_if_user_in_groups(self):
        """Return False if the user is missing a required group, True otherwise.

        The check only runs on Linux, when the 'linux_rights_warning' setting is
        truthy and the process is not running as root; otherwise returns True.
        If the ``groups`` command fails or prints nothing, membership cannot be
        confirmed and False is returned.
        """
        if sys.platform.startswith('linux') and config.get_settings()['linux_rights_warning'] and not is_admin():
            self.logger.debug('Checking Linux rights')
            # execute_command returns None on failure or empty output.
            output = self.execute_command('groups') or ''
            # Compare whole group names rather than substrings of the output.
            user_groups = set(output.split())
            missing = [group for group in self.REQUIRED_GROUPS if group not in user_groups]
            if missing:
                self.logger.debug('Current Linux user is not in groups: ' + ', '.join(missing))
                return False
        return True

    def add_user_groups(self):
        """On Linux, create the usbusers group and add $USER to dialout, tty and usbusers."""
        if sys.platform.startswith('linux'):
            self.logger.debug('Adding Linux user to necessary groups')
            # NOTE(review): groupadd normally requires root and is run here without
            # gksu/sudo, so it presumably only succeeds when the process is already
            # privileged; if it fails, usermod below may fail on the missing group.
            self.execute_command(['groupadd', 'usbusers'])
            self.execute_command('gksu "sudo usermod -a -G dialout,tty,usbusers $USER"', shell=True)

    def execute_command(self, command, shell=False):
        """Run ``command`` and return its stdout as text.

        :param command: argv list or command string (a shell line when shell=True).
        :param shell: forwarded to subprocess.Popen.
        :return: stdout decoded as UTF-8 (undecodable bytes replaced), or None if
            the process could not be started or wrote nothing to stdout.
        """
        self.logger.debug('Executing command: ' + str(command))
        try:
            process = Popen(command, shell=shell, stdout=PIPE, stderr=PIPE)
        except Exception as e:
            self.logger.debug('Error while executing command: "' + str(command) + '"\n' + str(e))
            return None
        stdout, stderr = process.communicate()
        if stderr:
            self.logger.debug('Execution stderr: ' + stderr.decode('utf-8', 'replace'))
        if stdout:
            stdout = stdout.decode('utf-8', 'replace')
            self.logger.debug('Execution result: ' + stdout)
            return stdout
        return None