Explorar el Código

qubes/mgmt: explicit method decorator and misc improvements

- Get rid of @not_in_api, exchange for explicit @api() decorator.
- Old @no_payload decorator becomes an argument (keyword-only).
- Factor out AbstractQubesMgmt class to be a base class for other mgmt
  backends.
- Use async def instead of @asyncio.coroutine.

QubesOS/qubes-issues#2622
Wojtek Porczyk hace 7 años
padre
commit
c4ef02c377
Se han modificado 1 ficheros con 116 adiciones y 108 borrados
  1. 116 108
      qubes/mgmt.py

+ 116 - 108
qubes/mgmt.py

@@ -22,10 +22,10 @@
 Qubes OS Management API
 '''
 
-import asyncio
+import functools
 import string
 
-import functools
+import pkg_resources
 
 import qubes.vm.qubesvm
 import qubes.storage
@@ -40,30 +40,59 @@ class PermissionDenied(Exception):
     pass
 
 
-def not_in_api(func):
-    '''Decorator for methods not intended to appear in API.
+def api(name, *, no_payload=False):
+    '''Decorator factory for methods intended to appear in API.
+
+    The decorated method can be called from public API using a child of
+    :py:class:`AbstractQubesMgmt` class. The method becomes "public", and can be
+    called using remote management interface.
+
+    :param str name: qrexec rpc method name
+    :param bool no_payload: if :py:obj:`True`, will barf on non-empty payload; \
+        also will not pass payload at all to the method
+
+    The expected function method should have one argument (other than usual
+    *self*), ``untrusted_payload``, which will contain the payload.
 
-    The decorated method cannot be called from public API using
-    :py:class:`QubesMgmt` class. The method becomes "private", and can be
-    called only as a helper for other methods.
+    .. warning::
+        This argument has to be named such, to remind the programmer that the
+        content of this variable is indeed untrusted.
+
+    If *no_payload* is true, then the method is called with no arguments.
     '''
-    func.not_in_api = True
-    return func
 
+    # TODO regexp for vm/dev classess; supply regexp groups as untrusted_ kwargs
 
-def no_payload(func):
-    @functools.wraps(func)
-    def wrapper(self, untrusted_payload):
-        if untrusted_payload != b'':
-            raise ProtocolError('unexpected payload')
-        return func(self)
-    return wrapper
+    def decorator(func):
+        if no_payload:
+            # the following assignment is needed for how closures work in Python
+            _func = func
+            @functools.wraps(_func)
+            def wrapper(self, untrusted_payload):
+                if untrusted_payload != b'':
+                    raise ProtocolError('unexpected payload')
+                return _func(self)
+            func = wrapper
 
+        func._rpcname = name  # pylint: disable=protected-access
+        return func
 
-class QubesMgmt(object):
-    '''Implementation of Qubes Management API calls
+    return decorator
+
+class AbstractQubesMgmt(object):
+    '''Common code for Qubes Management Protocol handling
+
+    Different interfaces can expose different API call sets, however they share
+    common protocol and common implementation framework. This class is the
+    latter.
 
-    This class contains all the methods available in the API.
+    To implement a new interface, inherit from this class and write at least one
+    method and decorate it with :py:func:`api` decorator. It will have access to
+    pre-defined attributes: :py:attr:`app`, :py:attr:`src`, :py:attr:`dest`,
+    :py:attr:`arg` and :py:attr:`method`.
+
+    There are also two helper functions for firing events associated with API
+    calls.
     '''
     def __init__(self, app, src, method, dest, arg):
         #: :py:class:`qubes.Qubes` object
@@ -81,62 +110,55 @@ class QubesMgmt(object):
         #: name of the method
         self.method = method.decode('ascii')
 
-        untrusted_func_name = self.method
-        if untrusted_func_name.startswith('mgmt.'):
-            untrusted_func_name = untrusted_func_name[5:]
-        untrusted_func_name = untrusted_func_name.lower().replace('.', '_')
+        untrusted_candidates = []
+        for attr in dir(self):
+            untrusted_func = getattr(self, attr)
 
-        if untrusted_func_name.startswith('_') \
-                or not '_' in untrusted_func_name:
-            raise ProtocolError(
-                'possibly malicious function name: {!r}'.format(
-                    untrusted_func_name))
+            if not callable(untrusted_func):
+                continue
 
-        try:
-            untrusted_func = getattr(self, untrusted_func_name)
-        except AttributeError:
-            raise ProtocolError(
-                'no such attribute: {!r}'.format(
-                    untrusted_func_name))
+            try:
+                # pylint: disable=protected-access
+                if untrusted_func._rpcname != self.method:
+                    continue
+            except AttributeError:
+                continue
 
-        if not asyncio.iscoroutinefunction(untrusted_func):
-            raise ProtocolError(
-                'no such method: {!r}'.format(
-                    untrusted_func_name))
+            untrusted_candidates.append(untrusted_func)
 
-        if getattr(untrusted_func, 'not_in_api', False):
-            raise ProtocolError(
-                'attempt to call private method: {!r}'.format(
-                    untrusted_func_name))
+        if not untrusted_candidates:
+            raise ProtocolError('no such method: {!r}'.format(self.method))
 
-        self.execute = untrusted_func
-        del untrusted_func_name
-        del untrusted_func
+        assert len(untrusted_candidates) == 1, \
+            'multiple candidates for method {!r}'.format(self.method)
 
-    #
-    # PRIVATE METHODS, not to be called via RPC
-    #
+        #: the method to execute
+        self.execute = untrusted_candidates[0]
+        del untrusted_candidates
 
-    @not_in_api
     def fire_event_for_permission(self, **kwargs):
         '''Fire an event on the source qube to check for permission'''
         return self.src.fire_event_pre('mgmt-permission:{}'.format(self.method),
             dest=self.dest, arg=self.arg, **kwargs)
 
-    @not_in_api
     def fire_event_for_filter(self, iterable, **kwargs):
         '''Fire an event on the source qube to filter for permission'''
         for selector in self.fire_event_for_permission(**kwargs):
             iterable = filter(selector, iterable)
         return iterable
 
-    #
-    # ACTUAL RPC CALLS
-    #
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_list(self):
+class QubesMgmt(AbstractQubesMgmt):
+    '''Implementation of Qubes Management API calls
+
+    This class contains all the methods available in the main API.
+
+    .. seealso::
+        https://www.qubes-os.org/doc/mgmt1/
+    '''
+
+    @api('mgmt.vm.List', no_payload=True)
+    async def vm_list(self):
         '''List all the domains'''
         assert not self.arg
 
@@ -151,9 +173,8 @@ class QubesMgmt(object):
                 vm.get_power_state())
             for vm in sorted(domains))
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_property_list(self):
+    @api('mgmt.vm.property.List', no_payload=True)
+    async def vm_property_list(self):
         '''List all properties on a qube'''
         assert not self.arg
 
@@ -161,9 +182,8 @@ class QubesMgmt(object):
 
         return ''.join('{}\n'.format(prop.__name__) for prop in properties)
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_property_get(self):
+    @api('mgmt.vm.property.Get', no_payload=True)
+    async def vm_property_get(self):
         '''Get a value of one property'''
         assert self.arg in self.dest.property_list()
 
@@ -192,8 +212,8 @@ class QubesMgmt(object):
                 property_type,
                 str(value) if value is not None else '')
 
-    @asyncio.coroutine
-    def vm_property_set(self, untrusted_payload):
+    @api('mgmt.vm.property.Set')
+    async def vm_property_set(self, untrusted_payload):
         assert self.arg in self.dest.property_list()
 
         property_def = self.dest.property_get_def(self.arg)
@@ -204,9 +224,8 @@ class QubesMgmt(object):
         setattr(self.dest, self.arg, newvalue)
         self.app.save()
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_property_help(self):
+    @api('mgmt.vm.property.Help', no_payload=True)
+    async def vm_property_help(self):
         '''Get help for one property'''
         assert self.arg in self.dest.property_list()
 
@@ -219,9 +238,8 @@ class QubesMgmt(object):
 
         return qubes.utils.format_doc(doc)
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_property_reset(self):
+    @api('mgmt.vm.property.Reset', no_payload=True)
+    async def vm_property_reset(self):
         '''Reset a property to a default value'''
         assert self.arg in self.dest.property_list()
 
@@ -230,17 +248,15 @@ class QubesMgmt(object):
         delattr(self.dest, self.arg)
         self.app.save()
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_volume_list(self):
+    @api('mgmt.vm.volume.List', no_payload=True)
+    async def vm_volume_list(self):
         assert not self.arg
 
         volume_names = self.fire_event_for_filter(self.dest.volumes.keys())
         return ''.join('{}\n'.format(name) for name in volume_names)
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_volume_info(self):
+    @api('mgmt.vm.volume.Info', no_payload=True)
+    async def vm_volume_info(self):
         assert self.arg in self.dest.volumes.keys()
 
         self.fire_event_for_permission()
@@ -253,9 +269,8 @@ class QubesMgmt(object):
         return ''.join('{}={}\n'.format(key, getattr(volume, key)) for key in
             volume_properties)
 
-    @asyncio.coroutine
-    @no_payload
-    def vm_volume_listsnapshots(self):
+    @api('mgmt.vm.volume.ListSnapshots', no_payload=True)
+    async def vm_volume_listsnapshots(self):
         assert self.arg in self.dest.volumes.keys()
 
         volume = self.dest.volumes[self.arg]
@@ -264,8 +279,8 @@ class QubesMgmt(object):
 
         return ''.join('{}\n'.format(revision) for revision in revisions)
 
-    @asyncio.coroutine
-    def vm_volume_revert(self, untrusted_payload):
+    @api('mgmt.vm.volume.Revert')
+    async def vm_volume_revert(self, untrusted_payload):
         assert self.arg in self.dest.volumes.keys()
         untrusted_revision = untrusted_payload.decode('ascii').strip()
         del untrusted_payload
@@ -280,8 +295,8 @@ class QubesMgmt(object):
         self.dest.storage.get_pool(volume).revert(revision)
         self.app.save()
 
-    @asyncio.coroutine
-    def vm_volume_resize(self, untrusted_payload):
+    @api('mgmt.vm.volume.Resize')
+    async def vm_volume_resize(self, untrusted_payload):
         assert self.arg in self.dest.volumes.keys()
         untrusted_size = untrusted_payload.decode('ascii').strip()
         del untrusted_payload
@@ -295,9 +310,8 @@ class QubesMgmt(object):
         self.dest.storage.resize(self.arg, size)
         self.app.save()
 
-    @asyncio.coroutine
-    @no_payload
-    def pool_list(self):
+    @api('mgmt.pool.List', no_payload=True)
+    async def pool_list(self):
         assert not self.arg
         assert self.dest.name == 'dom0'
 
@@ -305,9 +319,8 @@ class QubesMgmt(object):
 
         return ''.join('{}\n'.format(pool) for pool in pools)
 
-    @asyncio.coroutine
-    @no_payload
-    def pool_listdrivers(self):
+    @api('mgmt.pool.ListDrivers', no_payload=True)
+    async def pool_listdrivers(self):
         assert self.dest.name == 'dom0'
         assert not self.arg
 
@@ -318,9 +331,8 @@ class QubesMgmt(object):
             ' '.join(qubes.storage.driver_parameters(driver)))
             for driver in drivers)
 
-    @asyncio.coroutine
-    @no_payload
-    def pool_info(self):
+    @api('mgmt.pool.Info', no_payload=True)
+    async def pool_info(self):
         assert self.dest.name == 'dom0'
         assert self.arg in self.app.pools.keys()
 
@@ -331,8 +343,8 @@ class QubesMgmt(object):
         return ''.join('{}={}\n'.format(prop, val)
             for prop, val in sorted(pool.config.items()))
 
-    @asyncio.coroutine
-    def pool_add(self, untrusted_payload):
+    @api('mgmt.pool.Add')
+    async def pool_add(self, untrusted_payload):
         assert self.dest.name == 'dom0'
         drivers = qubes.storage.pool_drivers()
         assert self.arg in drivers
@@ -365,9 +377,8 @@ class QubesMgmt(object):
         self.app.add_pool(name=pool_name, driver=self.arg, **pool_config)
         self.app.save()
 
-    @asyncio.coroutine
-    @no_payload
-    def pool_remove(self):
+    @api('mgmt.pool.Remove', no_payload=True)
+    async def pool_remove(self):
         assert self.dest.name == 'dom0'
         assert self.arg in self.app.pools.keys()
 
@@ -376,9 +387,8 @@ class QubesMgmt(object):
         self.app.remove_pool(self.arg)
         self.app.save()
 
-    @asyncio.coroutine
-    @no_payload
-    def label_list(self):
+    @api('mgmt.label.List', no_payload=True)
+    async def label_list(self):
         assert self.dest.name == 'dom0'
         assert not self.arg
 
@@ -386,9 +396,8 @@ class QubesMgmt(object):
 
         return ''.join('{}\n'.format(label.name) for label in labels)
 
-    @asyncio.coroutine
-    @no_payload
-    def label_get(self):
+    @api('mgmt.label.Get', no_payload=True)
+    async def label_get(self):
         assert self.dest.name == 'dom0'
 
         try:
@@ -400,8 +409,8 @@ class QubesMgmt(object):
 
         return label.color
 
-    @asyncio.coroutine
-    def label_create(self, untrusted_payload):
+    @api('mgmt.label.Create')
+    async def label_create(self, untrusted_payload):
         assert self.dest.name == 'dom0'
 
         # don't confuse label name with label index
@@ -435,9 +444,8 @@ class QubesMgmt(object):
         self.app.labels[new_index] = label
         self.app.save()
 
-    @asyncio.coroutine
-    @no_payload
-    def label_remove(self):
+    @api('mgmt.label.Remove', no_payload=True)
+    async def label_remove(self):
         assert self.dest.name == 'dom0'
 
         try: