|
@@ -54,14 +54,12 @@ class TestVMCollection(dict):
|
|
|
class TestProcess(object):
|
|
|
def __init__(self, input_callback=None, stdout=None, stderr=None):
|
|
|
self.input_callback = input_callback
|
|
|
+ self.got_any_input = False
|
|
|
self.stdin = io.BytesIO()
|
|
|
# don't let anyone close it, before we get the value
|
|
|
self.stdin_close = self.stdin.close
|
|
|
- if self.input_callback:
|
|
|
- self.stdin.close = (
|
|
|
- lambda: self.input_callback(self.stdin.getvalue()))
|
|
|
- else:
|
|
|
- self.stdin.close = lambda: None
|
|
|
+ self.stdin.close = self.store_input
|
|
|
+ self.stdin.flush = self.store_input
|
|
|
if stdout == subprocess.PIPE:
|
|
|
self.stdout = io.BytesIO()
|
|
|
else:
|
|
@@ -72,6 +70,13 @@ class TestProcess(object):
|
|
|
self.stderr = stderr
|
|
|
self.returncode = 0
|
|
|
|
|
|
+ def store_input(self):
|
|
|
+ value = self.stdin.getvalue()
|
|
|
+ if (not self.got_any_input or value) and self.input_callback:
|
|
|
+ self.input_callback(self.stdin.getvalue())
|
|
|
+ self.got_any_input = True
|
|
|
+ self.stdin.truncate(0)
|
|
|
+
|
|
|
def communicate(self, input=None):
|
|
|
if input is not None:
|
|
|
self.stdin.write(input)
|
|
@@ -121,14 +126,17 @@ class _AssertNotRaisesContext(object):
|
|
|
|
|
|
|
|
|
class QubesTest(qubesadmin.app.QubesBase):
|
|
|
+ expected_service_calls = None
|
|
|
expected_calls = None
|
|
|
actual_calls = None
|
|
|
service_calls = None
|
|
|
|
|
|
def __init__(self):
|
|
|
super(QubesTest, self).__init__()
|
|
|
- #: expected calls and saved replies for them
|
|
|
+ #: expected Admin API calls and saved replies for them
|
|
|
self.expected_calls = {}
|
|
|
+ #: expected qrexec service calls and saved replies for them
|
|
|
+ self.expected_service_calls = {}
|
|
|
#: actual calls made
|
|
|
self.actual_calls = []
|
|
|
#: rpc service calls
|
|
@@ -152,6 +160,14 @@ class QubesTest(qubesadmin.app.QubesBase):
|
|
|
|
|
|
def run_service(self, dest, service, **kwargs):
|
|
|
self.service_calls.append((dest, service, kwargs))
|
|
|
+ call_key = (dest, service)
|
|
|
+ # TODO: consider it as a future extension, as a replacement for
|
|
|
+ # checking app.service_calls later
|
|
|
+ # if call_key not in self.expected_service_calls:
|
|
|
+ # raise AssertionError('Unexpected service call {!r}'.format(call_key))
|
|
|
+ if call_key in self.expected_service_calls:
|
|
|
+ kwargs = kwargs.copy()
|
|
|
+ kwargs['stdout'] = io.BytesIO(self.expected_service_calls[call_key])
|
|
|
return TestProcess(lambda input: self.service_calls.append((dest,
|
|
|
service, input)),
|
|
|
stdout=kwargs.get('stdout', None),
|