diff --git a/qubesadmin/tests/__init__.py b/qubesadmin/tests/__init__.py index a0fa1e6..996276f 100644 --- a/qubesadmin/tests/__init__.py +++ b/qubesadmin/tests/__init__.py @@ -52,7 +52,7 @@ class TestVMCollection(dict): class TestProcess(object): - def __init__(self, input_callback=None, stdout=None, stderr=None): + def __init__(self, input_callback=None, stdout=None, stderr=None, stdout_data=None): self.input_callback = input_callback self.got_any_input = False self.stdin = io.BytesIO() @@ -70,6 +70,10 @@ class TestProcess(object): self.stderr = io.BytesIO() else: self.stderr = stderr + if stdout_data: + self.stdout.write(stdout_data) + # Seek to head so that it can be read later + self.stdout.seek(0) self.returncode = 0 def store_input(self): @@ -91,7 +95,7 @@ class TestProcess(object): return 0 def poll(self): - return None + return self.returncode class _AssertNotRaisesContext(object): @@ -167,11 +171,12 @@ class QubesTest(qubesadmin.app.QubesBase): # raise AssertionError('Unexpected service call {!r}'.format(call_key)) if call_key in self.expected_service_calls: kwargs = kwargs.copy() - kwargs['stdout'] = io.BytesIO(self.expected_service_calls[call_key]) + kwargs['stdout_data'] = self.expected_service_calls[call_key] return TestProcess(lambda input: self.service_calls.append((dest, service, input)), stdout=kwargs.get('stdout', None), stderr=kwargs.get('stderr', None), + stdout_data=kwargs.get('stdout_data', None), )