# -*- encoding: utf8 -*-
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2017 Marek Marczykowski-Górecki
#                               <marmarek@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation; either version 2.1 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program; if not, see <http://www.gnu.org/licenses/>.
import subprocess
import traceback
import unittest

import io

import qubesadmin
import qubesadmin.app


class TestVM(object):
    def __init__(self, name, **kwargs):
        self.name = name
        self.klass = 'TestVM'
        for key, value in kwargs.items():
            setattr(self, key, value)

    def get_power_state(self):
        return getattr(self, 'power_state', 'Running')

    def __str__(self):
        return self.name

    def __lt__(self, other):
        if isinstance(other, TestVM):
            return self.name < other.name
        return NotImplemented


class TestVMCollection(dict):
    def __iter__(self):
        return iter(self.values())


class TestProcess(object):
    def __init__(self, input_callback=None, stdout=None, stderr=None):
        self.input_callback = input_callback
        self.got_any_input = False
        self.stdin = io.BytesIO()
        # don't let anyone close it, before we get the value
        self.stdin_close = self.stdin.close
        self.stdin.close = self.store_input
        self.stdin.flush = self.store_input
        if stdout == subprocess.PIPE or stdout == subprocess.DEVNULL \
                or stdout is None:
            self.stdout = io.BytesIO()
        else:
            self.stdout = stdout
        if stderr == subprocess.PIPE or stderr == subprocess.DEVNULL \
                or stderr is None:
            self.stderr = io.BytesIO()
        else:
            self.stderr = stderr
        self.returncode = 0

    def store_input(self):
        value = self.stdin.getvalue()
        if (not self.got_any_input or value) and self.input_callback:
            self.input_callback(self.stdin.getvalue())
        self.got_any_input = True
        self.stdin.truncate(0)

    def communicate(self, input=None):
        if input is not None:
            self.stdin.write(input)
        self.stdin.close()
        self.stdin_close()
        return self.stdout.read(), self.stderr.read()

    def wait(self):
        self.stdin_close()
        return 0

    def poll(self):
        return None


class _AssertNotRaisesContext(object):
    """A context manager used to implement TestCase.assertNotRaises methods.

    Stolen from unittest and hacked. Regexp support stripped.
    """ # pylint: disable=too-few-public-methods

    def __init__(self, expected, test_case, expected_regexp=None):
        if expected_regexp is not None:
            raise NotImplementedError('expected_regexp is unsupported')

        self.expected = expected
        self.exception = None

        self.failureException = test_case.failureException

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            return True

        if issubclass(exc_type, self.expected):
            raise self.failureException(
                "{!r} raised, traceback:\n{!s}".format(
                    exc_value, ''.join(traceback.format_tb(tb))))
        else:
            # pass through
            return False


class QubesTest(qubesadmin.app.QubesBase):
    expected_service_calls = None
    expected_calls = None
    actual_calls = None
    service_calls = None

    def __init__(self):
        super(QubesTest, self).__init__()
        #: expected Admin API calls and saved replies for them
        self.expected_calls = {}
        #: expected qrexec service calls and saved replies for them
        self.expected_service_calls = {}
        #: actual calls made
        self.actual_calls = []
        #: rpc service calls
        self.service_calls = []

    def qubesd_call(self, dest, method, arg=None, payload=None,
            payload_stream=None):
        if payload_stream:
            payload = (payload or b'') + payload_stream.read()
        call_key = (dest, method, arg, payload)
        self.actual_calls.append(call_key)
        if call_key not in self.expected_calls:
            raise AssertionError('Unexpected call {!r}'.format(call_key))
        return_data = self.expected_calls[call_key]
        if isinstance(return_data, list):
            try:
                return_data = return_data.pop(0)
            except IndexError:
                raise AssertionError('Extra call {!r}'.format(call_key))
        return self._parse_qubesd_response(return_data)

    def run_service(self, dest, service, **kwargs):
        self.service_calls.append((dest, service, kwargs))
        call_key = (dest, service)
        # TODO: consider it as a future extension, as a replacement for
        #  checking app.service_calls later
        # if call_key not in self.expected_service_calls:
        #     raise AssertionError('Unexpected service call {!r}'.format(call_key))
        if call_key in self.expected_service_calls:
            kwargs = kwargs.copy()
            kwargs['stdout'] = io.BytesIO(self.expected_service_calls[call_key])
        return TestProcess(lambda input: self.service_calls.append((dest,
            service, input)),
            stdout=kwargs.get('stdout', None),
            stderr=kwargs.get('stderr', None),
        )


class QubesTestCase(unittest.TestCase):
    def setUp(self):
        super(QubesTestCase, self).setUp()
        self.app = QubesTest()

    def assertAllCalled(self):
        self.assertEqual(
            set(self.app.expected_calls.keys()),
            set(self.app.actual_calls))
        # and also check if calls expected multiple times were called
        self.assertFalse([(call, ret)
            for call, ret in self.app.expected_calls.items() if
            isinstance(ret, list) and ret],
            'Some calls not called expected number of times')

    def assertNotRaises(self, excClass, callableObj=None, *args, **kwargs):
        """Fail if an exception of class excClass is raised
           by callableObj when invoked with arguments args and keyword
           arguments kwargs. If a different type of exception is
           raised, it will not be caught, and the test case will be
           deemed to have suffered an error, exactly as for an
           unexpected exception.

           If called with callableObj omitted or None, will return a
           context object used like this::

                with self.assertRaises(SomeException):
                    do_something()

           The context manager keeps a reference to the exception as
           the 'exception' attribute. This allows you to inspect the
           exception after the assertion::

               with self.assertRaises(SomeException) as cm:
                   do_something()
               the_exception = cm.exception
               self.assertEqual(the_exception.error_code, 3)
        """
        context = _AssertNotRaisesContext(excClass, self)
        if callableObj is None:
            return context
        with context:
            callableObj(*args, **kwargs)