__init__.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # -*- encoding: utf8 -*-
  2. #
  3. # The Qubes OS Project, http://www.qubes-os.org
  4. #
  5. # Copyright (C) 2017 Marek Marczykowski-Górecki
  6. # <marmarek@invisiblethingslab.com>
  7. #
  8. # This program is free software; you can redistribute it and/or modify
  9. # it under the terms of the GNU Lesser General Public License as published by
  10. # the Free Software Foundation; either version 2.1 of the License, or
  11. # (at your option) any later version.
  12. #
  13. # This program is distributed in the hope that it will be useful,
  14. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  15. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  16. # GNU Lesser General Public License for more details.
  17. #
  18. # You should have received a copy of the GNU Lesser General Public License along
  19. # with this program; if not, see <http://www.gnu.org/licenses/>.
  20. import subprocess
  21. import traceback
  22. import unittest
  23. import io
  24. import qubesadmin
  25. import qubesadmin.app
  26. class TestVM(object):
  27. def __init__(self, name, **kwargs):
  28. self.name = name
  29. self.klass = 'TestVM'
  30. for key, value in kwargs.items():
  31. setattr(self, key, value)
  32. def get_power_state(self):
  33. return getattr(self, 'power_state', 'Running')
  34. def __str__(self):
  35. return self.name
  36. def __lt__(self, other):
  37. if isinstance(other, TestVM):
  38. return self.name < other.name
  39. return NotImplemented
  40. class TestVMCollection(dict):
  41. def __iter__(self):
  42. return iter(self.values())
  43. class TestProcess(object):
  44. def __init__(self, input_callback=None, stdout=None, stderr=None, stdout_data=None):
  45. self.input_callback = input_callback
  46. self.got_any_input = False
  47. self.stdin = io.BytesIO()
  48. # don't let anyone close it, before we get the value
  49. self.stdin_close = self.stdin.close
  50. self.stdin.close = self.store_input
  51. self.stdin.flush = self.store_input
  52. if stdout == subprocess.PIPE or stdout == subprocess.DEVNULL \
  53. or stdout is None:
  54. self.stdout = io.BytesIO()
  55. else:
  56. self.stdout = stdout
  57. if stderr == subprocess.PIPE or stderr == subprocess.DEVNULL \
  58. or stderr is None:
  59. self.stderr = io.BytesIO()
  60. else:
  61. self.stderr = stderr
  62. if stdout_data:
  63. self.stdout.write(stdout_data)
  64. # Seek to head so that it can be read later
  65. self.stdout.seek(0)
  66. self.returncode = 0
  67. def store_input(self):
  68. value = self.stdin.getvalue()
  69. if (not self.got_any_input or value) and self.input_callback:
  70. self.input_callback(self.stdin.getvalue())
  71. self.got_any_input = True
  72. self.stdin.truncate(0)
  73. def communicate(self, input=None):
  74. if input is not None:
  75. self.stdin.write(input)
  76. self.stdin.close()
  77. self.stdin_close()
  78. return self.stdout.read(), self.stderr.read()
  79. def wait(self):
  80. self.stdin_close()
  81. return 0
  82. def poll(self):
  83. return self.returncode
  84. class _AssertNotRaisesContext(object):
  85. """A context manager used to implement TestCase.assertNotRaises methods.
  86. Stolen from unittest and hacked. Regexp support stripped.
  87. """ # pylint: disable=too-few-public-methods
  88. def __init__(self, expected, test_case, expected_regexp=None):
  89. if expected_regexp is not None:
  90. raise NotImplementedError('expected_regexp is unsupported')
  91. self.expected = expected
  92. self.exception = None
  93. self.failureException = test_case.failureException
  94. def __enter__(self):
  95. return self
  96. def __exit__(self, exc_type, exc_value, tb):
  97. if exc_type is None:
  98. return True
  99. if issubclass(exc_type, self.expected):
  100. raise self.failureException(
  101. "{!r} raised, traceback:\n{!s}".format(
  102. exc_value, ''.join(traceback.format_tb(tb))))
  103. else:
  104. # pass through
  105. return False
  106. class QubesTest(qubesadmin.app.QubesBase):
  107. expected_service_calls = None
  108. expected_calls = None
  109. actual_calls = None
  110. service_calls = None
  111. def __init__(self):
  112. super(QubesTest, self).__init__()
  113. #: expected Admin API calls and saved replies for them
  114. self.expected_calls = {}
  115. #: expected qrexec service calls and saved replies for them
  116. self.expected_service_calls = {}
  117. #: actual calls made
  118. self.actual_calls = []
  119. #: rpc service calls
  120. self.service_calls = []
  121. def qubesd_call(self, dest, method, arg=None, payload=None,
  122. payload_stream=None):
  123. if payload_stream:
  124. payload = (payload or b'') + payload_stream.read()
  125. call_key = (dest, method, arg, payload)
  126. self.actual_calls.append(call_key)
  127. if call_key not in self.expected_calls:
  128. raise AssertionError('Unexpected call {!r}'.format(call_key))
  129. return_data = self.expected_calls[call_key]
  130. if isinstance(return_data, list):
  131. try:
  132. return_data = return_data.pop(0)
  133. except IndexError:
  134. raise AssertionError('Extra call {!r}'.format(call_key))
  135. return self._parse_qubesd_response(return_data)
  136. def run_service(self, dest, service, **kwargs):
  137. self.service_calls.append((dest, service, kwargs))
  138. call_key = (dest, service)
  139. # TODO: consider it as a future extension, as a replacement for
  140. # checking app.service_calls later
  141. # if call_key not in self.expected_service_calls:
  142. # raise AssertionError('Unexpected service call {!r}'.format(call_key))
  143. if call_key in self.expected_service_calls:
  144. kwargs = kwargs.copy()
  145. kwargs['stdout_data'] = self.expected_service_calls[call_key]
  146. return TestProcess(lambda input: self.service_calls.append((dest,
  147. service, input)),
  148. stdout=kwargs.get('stdout', None),
  149. stderr=kwargs.get('stderr', None),
  150. stdout_data=kwargs.get('stdout_data', None),
  151. )
  152. class QubesTestCase(unittest.TestCase):
  153. def setUp(self):
  154. super(QubesTestCase, self).setUp()
  155. self.app = QubesTest()
  156. def assertAllCalled(self):
  157. self.assertEqual(
  158. set(self.app.expected_calls.keys()),
  159. set(self.app.actual_calls))
  160. # and also check if calls expected multiple times were called
  161. self.assertFalse([(call, ret)
  162. for call, ret in self.app.expected_calls.items() if
  163. isinstance(ret, list) and ret],
  164. 'Some calls not called expected number of times')
  165. def assertNotRaises(self, excClass, callableObj=None, *args, **kwargs):
  166. """Fail if an exception of class excClass is raised
  167. by callableObj when invoked with arguments args and keyword
  168. arguments kwargs. If a different type of exception is
  169. raised, it will not be caught, and the test case will be
  170. deemed to have suffered an error, exactly as for an
  171. unexpected exception.
  172. If called with callableObj omitted or None, will return a
  173. context object used like this::
  174. with self.assertRaises(SomeException):
  175. do_something()
  176. The context manager keeps a reference to the exception as
  177. the 'exception' attribute. This allows you to inspect the
  178. exception after the assertion::
  179. with self.assertRaises(SomeException) as cm:
  180. do_something()
  181. the_exception = cm.exception
  182. self.assertEqual(the_exception.error_code, 3)
  183. """
  184. context = _AssertNotRaisesContext(excClass, self)
  185. if callableObj is None:
  186. return context
  187. with context:
  188. callableObj(*args, **kwargs)