tests: extend run_service mockup for pre-recorded output

And also handle input written if just stdin.flush() is called but not
stdin.close().
This commit is contained in:
Marek Marczykowski-Górecki 2019-10-19 04:31:11 +02:00
parent 889e606d7c
commit 7fb90e0233
No known key found for this signature in database
GPG Key ID: 063938BA42CFA724

View File

@ -54,14 +54,12 @@ class TestVMCollection(dict):
class TestProcess(object): class TestProcess(object):
def __init__(self, input_callback=None, stdout=None, stderr=None): def __init__(self, input_callback=None, stdout=None, stderr=None):
self.input_callback = input_callback self.input_callback = input_callback
self.got_any_input = False
self.stdin = io.BytesIO() self.stdin = io.BytesIO()
# don't let anyone close it, before we get the value # don't let anyone close it, before we get the value
self.stdin_close = self.stdin.close self.stdin_close = self.stdin.close
if self.input_callback: self.stdin.close = self.store_input
self.stdin.close = ( self.stdin.flush = self.store_input
lambda: self.input_callback(self.stdin.getvalue()))
else:
self.stdin.close = lambda: None
if stdout == subprocess.PIPE: if stdout == subprocess.PIPE:
self.stdout = io.BytesIO() self.stdout = io.BytesIO()
else: else:
@ -72,6 +70,13 @@ class TestProcess(object):
self.stderr = stderr self.stderr = stderr
self.returncode = 0 self.returncode = 0
def store_input(self):
value = self.stdin.getvalue()
if (not self.got_any_input or value) and self.input_callback:
self.input_callback(self.stdin.getvalue())
self.got_any_input = True
self.stdin.truncate(0)
def communicate(self, input=None): def communicate(self, input=None):
if input is not None: if input is not None:
self.stdin.write(input) self.stdin.write(input)
@ -121,14 +126,17 @@ class _AssertNotRaisesContext(object):
class QubesTest(qubesadmin.app.QubesBase): class QubesTest(qubesadmin.app.QubesBase):
expected_service_calls = None
expected_calls = None expected_calls = None
actual_calls = None actual_calls = None
service_calls = None service_calls = None
def __init__(self): def __init__(self):
super(QubesTest, self).__init__() super(QubesTest, self).__init__()
#: expected calls and saved replies for them #: expected Admin API calls and saved replies for them
self.expected_calls = {} self.expected_calls = {}
#: expected qrexec service calls and saved replies for them
self.expected_service_calls = {}
#: actual calls made #: actual calls made
self.actual_calls = [] self.actual_calls = []
#: rpc service calls #: rpc service calls
@ -152,6 +160,14 @@ class QubesTest(qubesadmin.app.QubesBase):
def run_service(self, dest, service, **kwargs): def run_service(self, dest, service, **kwargs):
self.service_calls.append((dest, service, kwargs)) self.service_calls.append((dest, service, kwargs))
call_key = (dest, service)
# TODO: consider it as a future extension, as a replacement for
# checking app.service_calls later
# if call_key not in self.expected_service_calls:
# raise AssertionError('Unexpected service call {!r}'.format(call_key))
if call_key in self.expected_service_calls:
kwargs = kwargs.copy()
kwargs['stdout'] = io.BytesIO(self.expected_service_calls[call_key])
return TestProcess(lambda input: self.service_calls.append((dest, return TestProcess(lambda input: self.service_calls.append((dest,
service, input)), service, input)),
stdout=kwargs.get('stdout', None), stdout=kwargs.get('stdout', None),