diff --git a/qubes/tests/__init__.py b/qubes/tests/__init__.py index b8ac3b43..8c1deb7b 100644 --- a/qubes/tests/__init__.py +++ b/qubes/tests/__init__.py @@ -1204,7 +1204,7 @@ def load_tests(loader, tests, pattern): # pylint: disable=unused-argument # 'qubes.tests.regressions', # external modules -# 'qubes.tests.extra', + 'qubes.tests.extra', ): tests.addTests(loader.loadTestsFromName(modname)) diff --git a/qubes/tests/extra.py b/qubes/tests/extra.py index d8985b0b..a60cdfab 100644 --- a/qubes/tests/extra.py +++ b/qubes/tests/extra.py @@ -19,10 +19,71 @@ # import sys + +import asyncio +import subprocess import pkg_resources import qubes.tests import qubes.vm.appvm +class ProcessWrapper(object): + def __init__(self, proc, loop=None): + self._proc = proc + self._loop = loop or asyncio.get_event_loop() + + def __getattr__(self, item): + return getattr(self._proc, item) + + def __setattr__(self, key, value): + if key.startswith('_'): + return super(ProcessWrapper, self).__setattr__(key, value) + return setattr(self._proc, key, value) + + def communicate(self, input=None): + return self._loop.run_until_complete(self._proc.communicate(input)) + +class VMWrapper(object): + '''Wrap VM object to provide stable API for basic operations''' + def __init__(self, vm, loop=None): + self._vm = vm + self._loop = loop or asyncio.get_event_loop() + + def __getattr__(self, item): + return getattr(self._vm, item) + + def __setattr__(self, key, value): + if key.startswith('_'): + return super(VMWrapper, self).__setattr__(key, value) + return setattr(self._vm, key, value) + + def __str__(self): + return str(self._vm) + + def __eq__(self, other): + return self._vm == other + + def start(self): + return self._loop.run_until_complete(self._vm.start()) + + def shutdown(self): + return self._loop.run_until_complete(self._vm.shutdown()) + + def run(self, command, wait=False, user=None, passio_popen=False, + passio_stderr=False, **kwargs): + if wait: + try: + self._loop.run_until_complete( + self._vm.run_for_stdio(command, user=user)) + except subprocess.CalledProcessError as err: + return err.returncode + return 0 + elif passio_popen: + p = self._loop.run_until_complete(self._vm.run(command, user=user, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE if passio_stderr else None)) + return ProcessWrapper(p, self._loop) + class ExtraTestCase(qubes.tests.SystemTestCase): @@ -31,6 +92,17 @@ class ExtraTestCase(qubes.tests.SystemTestCase): def setUp(self): super(ExtraTestCase, self).setUp() self.init_default_template(self.template) + if self.template is not None: + # also use this template for DispVMs + dispvm_base = self.app.add_new_vm('AppVM', + name=self.make_vm_name('dvm'), + template=self.template, label='red', template_for_dispvms=True) + self.loop.run_until_complete(dispvm_base.create_on_disk()) + self.app.default_dispvm = dispvm_base + + def tearDown(self): + self.app.default_dispvm = None + super(ExtraTestCase, self).tearDown() def create_vms(self, names): """ @@ -49,13 +121,14 @@ class ExtraTestCase(qubes.tests.SystemTestCase): name=self.make_vm_name(vmname), template=template, label='red') - vm.create_on_disk() + self.loop.run_until_complete(vm.create_on_disk()) self.app.save() # get objects after reload vms = [] for vmname in names: - vms.append(self.app.domains[self.make_vm_name(vmname)]) + vms.append(VMWrapper(self.app.domains[self.make_vm_name(vmname)], + loop=self.loop)) return vms def enable_network(self): @@ -97,6 +170,6 @@ def load_tests(loader, tests, pattern): ExtraForTemplateLoadFailure = type('ExtraForTemplateLoadFailure', (qubes.tests.QubesTestCase,), {entry.name: runTest}) - tests.addTest(ExtraLoadFailure(entry.name)) + tests.addTest(ExtraForTemplateLoadFailure(entry.name)) return tests