diff --git a/qubesmgmt/tests/__init__.py b/qubesmgmt/tests/__init__.py index 26b2e06..dfbd836 100644 --- a/qubesmgmt/tests/__init__.py +++ b/qubesmgmt/tests/__init__.py @@ -19,6 +19,8 @@ # with this program; if not, see . import unittest +import io + import qubesmgmt import qubesmgmt.app @@ -40,14 +42,41 @@ class TestVM(object): return self.name < other.name return NotImplemented + class TestVMCollection(dict): def __iter__(self): return iter(self.values()) +class TestProcess(object): + def __init__(self, input_callback=None, stdout=None, stderr=None): + self.input_callback = input_callback + self.stdin = io.BytesIO() + # don't let anyone close it, before we get the value + self.stdin_close = self.stdin.close + if self.input_callback: + self.stdin.close = ( + lambda: self.input_callback(self.stdin.getvalue())) + else: + self.stdin.close = lambda: None + self.stdout = stdout + self.stderr = stderr + self.returncode = 0 + + def communicate(self, input=None): + self.stdin.write(input) + self.stdin.close() + self.stdin_close() + return self.stdout, self.stderr + + def wait(self): + self.stdin_close() + return 0 + class QubesTest(qubesmgmt.app.QubesBase): expected_calls = None actual_calls = None + service_calls = None def __init__(self): super(QubesTest, self).__init__() @@ -55,6 +84,8 @@ class QubesTest(qubesmgmt.app.QubesBase): self.expected_calls = {} #: actual calls made self.actual_calls = [] + #: rpc service calls + self.service_calls = [] def qubesd_call(self, dest, method, arg=None, payload=None): call_key = (dest, method, arg, payload) @@ -64,6 +95,11 @@ class QubesTest(qubesmgmt.app.QubesBase): return_data = self.expected_calls[call_key] return self._parse_qubesd_response(return_data) + def run_service(self, dest, service, **kwargs): + self.service_calls.append((dest, service, kwargs)) + return TestProcess(lambda input: self.service_calls.append((dest, + service, input))) + class QubesTestCase(unittest.TestCase): def setUp(self):