diff --git a/qubesadmin/tests/__init__.py b/qubesadmin/tests/__init__.py index a116039..8299c0c 100644 --- a/qubesadmin/tests/__init__.py +++ b/qubesadmin/tests/__init__.py @@ -54,14 +54,12 @@ class TestVMCollection(dict): class TestProcess(object): def __init__(self, input_callback=None, stdout=None, stderr=None): self.input_callback = input_callback + self.got_any_input = False self.stdin = io.BytesIO() # don't let anyone close it, before we get the value self.stdin_close = self.stdin.close - if self.input_callback: - self.stdin.close = ( - lambda: self.input_callback(self.stdin.getvalue())) - else: - self.stdin.close = lambda: None + self.stdin.close = self.store_input + self.stdin.flush = self.store_input if stdout == subprocess.PIPE: self.stdout = io.BytesIO() else: @@ -72,6 +70,13 @@ class TestProcess(object): self.stderr = stderr self.returncode = 0 + def store_input(self): + value = self.stdin.getvalue() + if (not self.got_any_input or value) and self.input_callback: + self.input_callback(self.stdin.getvalue()) + self.got_any_input = True + self.stdin.truncate(0) + def communicate(self, input=None): if input is not None: self.stdin.write(input) @@ -121,14 +126,17 @@ class _AssertNotRaisesContext(object): class QubesTest(qubesadmin.app.QubesBase): + expected_service_calls = None expected_calls = None actual_calls = None service_calls = None def __init__(self): super(QubesTest, self).__init__() - #: expected calls and saved replies for them + #: expected Admin API calls and saved replies for them self.expected_calls = {} + #: expected qrexec service calls and saved replies for them + self.expected_service_calls = {} #: actual calls made self.actual_calls = [] #: rpc service calls @@ -152,6 +160,14 @@ class QubesTest(qubesadmin.app.QubesBase): def run_service(self, dest, service, **kwargs): self.service_calls.append((dest, service, kwargs)) + call_key = (dest, service) + # TODO: consider it as a future extension, as a replacement for + # checking app.service_calls later + # if call_key not in self.expected_service_calls: + # raise AssertionError('Unexpected service call {!r}'.format(call_key)) + if call_key in self.expected_service_calls: + kwargs = kwargs.copy() + kwargs['stdout'] = io.BytesIO(self.expected_service_calls[call_key]) return TestProcess(lambda input: self.service_calls.append((dest, service, input)), stdout=kwargs.get('stdout', None),