Parcourir la source

tests: extend run_service mockup for pre-recorded output

And also handle input written if just stdin.flush() is called but not
stdin.close().
Marek Marczykowski-Górecki il y a 4 ans
Parent
commit
7fb90e0233
1 fichiers modifiés avec 22 ajouts et 6 suppressions
  1. 22 6
      qubesadmin/tests/__init__.py

+ 22 - 6
qubesadmin/tests/__init__.py

@@ -54,14 +54,12 @@ class TestVMCollection(dict):
 class TestProcess(object):
     def __init__(self, input_callback=None, stdout=None, stderr=None):
         self.input_callback = input_callback
+        self.got_any_input = False
         self.stdin = io.BytesIO()
         # don't let anyone close it, before we get the value
         self.stdin_close = self.stdin.close
-        if self.input_callback:
-            self.stdin.close = (
-                lambda: self.input_callback(self.stdin.getvalue()))
-        else:
-            self.stdin.close = lambda: None
+        self.stdin.close = self.store_input
+        self.stdin.flush = self.store_input
         if stdout == subprocess.PIPE:
             self.stdout = io.BytesIO()
         else:
@@ -72,6 +70,13 @@ class TestProcess(object):
             self.stderr = stderr
         self.returncode = 0
 
+    def store_input(self):
+        value = self.stdin.getvalue()
+        if (not self.got_any_input or value) and self.input_callback:
+            self.input_callback(self.stdin.getvalue())
+        self.got_any_input = True
+        self.stdin.truncate(0)
+
     def communicate(self, input=None):
         if input is not None:
             self.stdin.write(input)
@@ -121,14 +126,17 @@ class _AssertNotRaisesContext(object):
 
 
 class QubesTest(qubesadmin.app.QubesBase):
+    expected_service_calls = None
     expected_calls = None
     actual_calls = None
     service_calls = None
 
     def __init__(self):
         super(QubesTest, self).__init__()
-        #: expected calls and saved replies for them
+        #: expected Admin API calls and saved replies for them
         self.expected_calls = {}
+        #: expected qrexec service calls and saved replies for them
+        self.expected_service_calls = {}
         #: actual calls made
         self.actual_calls = []
         #: rpc service calls
@@ -152,6 +160,14 @@ class QubesTest(qubesadmin.app.QubesBase):
 
     def run_service(self, dest, service, **kwargs):
         self.service_calls.append((dest, service, kwargs))
+        call_key = (dest, service)
+        # TODO: consider it as a future extension, as a replacement for
+        #  checking app.service_calls later
+        # if call_key not in self.expected_service_calls:
+        #     raise AssertionError('Unexpected service call {!r}'.format(call_key))
+        if call_key in self.expected_service_calls:
+            kwargs = kwargs.copy()
+            kwargs['stdout'] = io.BytesIO(self.expected_service_calls[call_key])
         return TestProcess(lambda input: self.service_calls.append((dest,
             service, input)),
             stdout=kwargs.get('stdout', None),