Bläddra i källkod

qubes: reorganise API protocols

Now instantiating API servers is handled by common function. This is,
among other reasons, for creating ad-hoc sockets for tests.
Wojtek Porczyk 7 år sedan
förälder
incheckning
858e547525
6 ändrade filer med 202 tillägg och 196 borttagningar
  1. 179 0
      qubes/api/__init__.py
  2. 2 0
      qubes/api/admin.py
  3. 2 0
      qubes/api/internal.py
  4. 2 0
      qubes/api/misc.py
  5. 1 0
      qubes/tests/integ/vm_qrexec_gui.py
  6. 16 196
      qubes/tools/qubesd.py

+ 179 - 0
qubes/api/__init__.py

@@ -18,9 +18,16 @@
 #
 # You should have received a copy of the GNU General Public License along
 # with this program; if not, see <http://www.gnu.org/licenses/>.
+
 import asyncio
 import functools
+import io
+import os
+import shutil
+import struct
+import traceback
 
+import qubes.exc
 
 class ProtocolError(AssertionError):
     '''Raised when something is wrong with data received'''
@@ -175,3 +182,175 @@ class AbstractQubesAPI(object):
         '''Fire an event on the source qube to filter for permission'''
         return apply_filters(iterable,
             self.fire_event_for_permission(**kwargs))
+
+
+class QubesDaemonProtocol(asyncio.Protocol):
+    buffer_size = 65536
+    header = struct.Struct('Bx')
+
+    def __init__(self, handler, *args, app, debug=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.handler = handler
+        self.app = app
+        self.untrusted_buffer = io.BytesIO()
+        self.len_untrusted_buffer = 0
+        self.transport = None
+        self.debug = debug
+        self.event_sent = False
+        self.mgmt = None
+
+    def connection_made(self, transport):
+        self.transport = transport
+
+    def connection_lost(self, exc):
+        self.untrusted_buffer.close()
+        # for cancellable operation, interrupt it, otherwise it will do nothing
+        if self.mgmt is not None:
+            self.mgmt.cancel()
+        self.transport = None
+
+    def data_received(self, untrusted_data):  # pylint: disable=arguments-differ
+        if self.len_untrusted_buffer + len(untrusted_data) > self.buffer_size:
+            self.app.log.warning('request too long')
+            self.transport.abort()
+            self.untrusted_buffer.close()
+            return
+
+        self.len_untrusted_buffer += \
+            self.untrusted_buffer.write(untrusted_data)
+
+    def eof_received(self):
+        try:
+            src, meth, dest, arg, untrusted_payload = \
+                self.untrusted_buffer.getvalue().split(b'\0', 4)
+        except ValueError:
+            self.app.log.warning('framing error')
+            self.transport.abort()
+            return
+        finally:
+            self.untrusted_buffer.close()
+
+        asyncio.ensure_future(self.respond(
+            src, meth, dest, arg, untrusted_payload=untrusted_payload))
+
+        return True
+
+    @asyncio.coroutine
+    def respond(self, src, meth, dest, arg, *, untrusted_payload):
+        try:
+            self.mgmt = self.handler(self.app, src, meth, dest, arg,
+                self.send_event)
+            response = yield from self.mgmt.execute(
+                untrusted_payload=untrusted_payload)
+            assert not (self.event_sent and response)
+            if self.transport is None:
+                return
+
+        # except clauses will fall through to transport.abort() below
+
+        except PermissionDenied:
+            self.app.log.warning(
+                'permission denied for call %s+%s (%s → %s) '
+                'with payload of %d bytes',
+                    meth, arg, src, dest, len(untrusted_payload))
+
+        except ProtocolError:
+            self.app.log.warning(
+                'protocol error for call %s+%s (%s → %s) '
+                'with payload of %d bytes',
+                    meth, arg, src, dest, len(untrusted_payload))
+
+        except qubes.exc.QubesException as err:
+            msg = ('%r while calling '
+                'src=%r meth=%r dest=%r arg=%r len(untrusted_payload)=%d')
+
+            if self.debug:
+                self.app.log.exception(msg,
+                    err, src, meth, dest, arg, len(untrusted_payload))
+            else:
+                self.app.log.info(msg,
+                    err, src, meth, dest, arg, len(untrusted_payload))
+            if self.transport is not None:
+                self.send_exception(err)
+                self.transport.write_eof()
+                self.transport.close()
+            return
+
+        except Exception:  # pylint: disable=broad-except
+            self.app.log.exception(
+                'unhandled exception while calling '
+                'src=%r meth=%r dest=%r arg=%r len(untrusted_payload)=%d',
+                    src, meth, dest, arg, len(untrusted_payload))
+
+        else:
+            if not self.event_sent:
+                self.send_response(response)
+            try:
+                self.transport.write_eof()
+            except NotImplementedError:
+                pass
+            self.transport.close()
+            return
+
+        # this is reached if from except: blocks; do not put it in finally:,
+        # because this will prevent the good case from sending the reply
+        self.transport.abort()
+
+    def send_header(self, *args):
+        self.transport.write(self.header.pack(*args))
+
+    def send_response(self, content):
+        assert not self.event_sent
+        self.send_header(0x30)
+        if content is not None:
+            self.transport.write(content.encode('utf-8'))
+
+    def send_event(self, subject, event, **kwargs):
+        self.event_sent = True
+        self.send_header(0x31)
+
+        if subject is not self.app:
+            self.transport.write(subject.name.encode('ascii'))
+        self.transport.write(b'\0')
+
+        self.transport.write(event.encode('ascii') + b'\0')
+
+        for k, v in kwargs.items():
+            self.transport.write('{}\0{}\0'.format(k, str(v)).encode('ascii'))
+        self.transport.write(b'\0')
+
+    def send_exception(self, exc):
+        self.send_header(0x32)
+
+        self.transport.write(type(exc).__name__.encode() + b'\0')
+
+        if self.debug:
+            self.transport.write(''.join(traceback.format_exception(
+                type(exc), exc, exc.__traceback__)).encode('utf-8'))
+        self.transport.write(b'\0')
+
+        self.transport.write(str(exc).encode('utf-8') + b'\0')
+
+
+_umask_lock = asyncio.Lock()
+
+@asyncio.coroutine
+def create_server(sockpath, handler, app, debug=False, *, loop=None):
+    loop = loop or asyncio.get_event_loop()
+    try:
+        os.unlink(sockpath)
+    except FileNotFoundError:
+        pass
+
+    with (yield from _umask_lock):
+        old_umask = os.umask(0o007)
+        try:
+            server = yield from loop.create_unix_server(
+                functools.partial(QubesDaemonProtocol,
+                    handler, app=app, debug=debug),
+                sockpath)
+        finally:
+            os.umask(old_umask)
+
+    shutil.chown(sockpath, group='qubes')
+    return server

+ 2 - 0
qubes/api/admin.py

@@ -35,6 +35,8 @@ import qubes.utils
 import qubes.vm
 import qubes.vm.qubesvm
 
+QUBESD_ADMIN_SOCK = '/var/run/qubesd.sock'
+
 
 class QubesMgmtEventsDispatcher(object):
     def __init__(self, filters, send_event):

+ 2 - 0
qubes/api/internal.py

@@ -29,6 +29,8 @@ import qubes.api.admin
 import qubes.vm.adminvm
 import qubes.vm.dispvm
 
+QUBESD_INTERNAL_SOCK = '/var/run/qubesd.internal.sock'
+
 
 class QubesInternalAPI(qubes.api.AbstractQubesAPI):
     ''' Communication interface for dom0 components,

+ 2 - 0
qubes/api/misc.py

@@ -28,6 +28,8 @@ import qubes.api
 import qubes.api.admin
 import qubes.vm.dispvm
 
+QUBESD_MISC_SOCK = '/var/run/qubesd.misc.sock'
+
 
 class QubesMiscAPI(qubes.api.AbstractQubesAPI):
     @qubes.api.method('qubes.FeaturesRequest', no_payload=True)

+ 1 - 0
qubes/tests/integ/vm_qrexec_gui.py

@@ -260,6 +260,7 @@ class TC_00_AppVMMixin(qubes.tests.SystemTestsMixin):
         self.loop.run_until_complete(self.testvm1.start())
         self.loop.run_until_complete(run(self))
 
+    @unittest.skip('#2851, because there is no GUI in vm')
     def test_052_qrexec_vm_service_eof(self):
         """Test for EOF transmission VM(src)->VM(dst)"""
 

+ 16 - 196
qubes/tools/qubesd.py

@@ -1,13 +1,8 @@
 #!/usr/bin/env python3.6
 
 import asyncio
-import functools
-import io
 import os
-import shutil
 import signal
-import struct
-import traceback
 
 import libvirtaio
 
@@ -19,160 +14,7 @@ import qubes.api.misc
 import qubes.utils
 import qubes.vm.qubesvm
 
-QUBESD_SOCK = '/var/run/qubesd.sock'
-QUBESD_INTERNAL_SOCK = '/var/run/qubesd.internal.sock'
-QUBESD_MISC_SOCK = '/var/run/qubesd.misc.sock'
-
-class QubesDaemonProtocol(asyncio.Protocol):
-    buffer_size = 65536
-    header = struct.Struct('Bx')
-
-    def __init__(self, handler, *args, app, debug=False, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.handler = handler
-        self.app = app
-        self.untrusted_buffer = io.BytesIO()
-        self.len_untrusted_buffer = 0
-        self.transport = None
-        self.debug = debug
-        self.event_sent = False
-        self.mgmt = None
-
-    def connection_made(self, transport):
-        self.transport = transport
-
-    def connection_lost(self, exc):
-        self.untrusted_buffer.close()
-        # for cancellable operation, interrupt it, otherwise it will do nothing
-        if self.mgmt is not None:
-            self.mgmt.cancel()
-        self.transport = None
-
-    def data_received(self, untrusted_data):  # pylint: disable=arguments-differ
-        if self.len_untrusted_buffer + len(untrusted_data) > self.buffer_size:
-            self.app.log.warning('request too long')
-            self.transport.abort()
-            self.untrusted_buffer.close()
-            return
-
-        self.len_untrusted_buffer += \
-            self.untrusted_buffer.write(untrusted_data)
-
-    def eof_received(self):
-        try:
-            src, method, dest, arg, untrusted_payload = \
-                self.untrusted_buffer.getvalue().split(b'\0', 4)
-        except ValueError:
-            self.app.log.warning('framing error')
-            self.transport.abort()
-            return
-        finally:
-            self.untrusted_buffer.close()
-
-        asyncio.ensure_future(self.respond(
-            src, method, dest, arg, untrusted_payload=untrusted_payload))
-
-        return True
-
-    @asyncio.coroutine
-    def respond(self, src, method, dest, arg, *, untrusted_payload):
-        try:
-            self.mgmt = self.handler(self.app, src, method, dest, arg,
-                self.send_event)
-            response = yield from self.mgmt.execute(
-                untrusted_payload=untrusted_payload)
-            assert not (self.event_sent and response)
-            if self.transport is None:
-                return
-
-        # except clauses will fall through to transport.abort() below
-
-        except qubes.api.PermissionDenied:
-            self.app.log.warning(
-                'permission denied for call %s+%s (%s → %s) '
-                'with payload of %d bytes',
-                    method, arg, src, dest, len(untrusted_payload))
-
-        except qubes.api.ProtocolError:
-            self.app.log.warning(
-                'protocol error for call %s+%s (%s → %s) '
-                'with payload of %d bytes',
-                    method, arg, src, dest, len(untrusted_payload))
-
-        except qubes.exc.QubesException as err:
-            msg = ('%r while calling '
-                'src=%r method=%r dest=%r arg=%r len(untrusted_payload)=%d')
-
-            if self.debug:
-                self.app.log.exception(msg,
-                    err, src, method, dest, arg, len(untrusted_payload))
-            else:
-                self.app.log.info(msg,
-                    err, src, method, dest, arg, len(untrusted_payload))
-            if self.transport is not None:
-                self.send_exception(err)
-                self.transport.write_eof()
-                self.transport.close()
-            return
-
-        except Exception:  # pylint: disable=broad-except
-            self.app.log.exception(
-                'unhandled exception while calling '
-                'src=%r method=%r dest=%r arg=%r len(untrusted_payload)=%d',
-                    src, method, dest, arg, len(untrusted_payload))
-
-        else:
-            if not self.event_sent:
-                self.send_response(response)
-            try:
-                self.transport.write_eof()
-            except NotImplementedError:
-                pass
-            self.transport.close()
-            return
-
-        # this is reached if from except: blocks; do not put it in finally:,
-        # because this will prevent the good case from sending the reply
-        self.transport.abort()
-
-
-    def send_header(self, *args):
-        self.transport.write(self.header.pack(*args))
-
-    def send_response(self, content):
-        assert not self.event_sent
-        self.send_header(0x30)
-        if content is not None:
-            self.transport.write(content.encode('utf-8'))
-
-    def send_event(self, subject, event, **kwargs):
-        self.event_sent = True
-        self.send_header(0x31)
-
-        if subject is not self.app:
-            self.transport.write(subject.name.encode('ascii'))
-        self.transport.write(b'\0')
-
-        self.transport.write(event.encode('ascii') + b'\0')
-
-        for k, v in kwargs.items():
-            self.transport.write('{}\0{}\0'.format(k, str(v)).encode('ascii'))
-        self.transport.write(b'\0')
-
-    def send_exception(self, exc):
-        self.send_header(0x32)
-
-        self.transport.write(type(exc).__name__.encode() + b'\0')
-
-        if self.debug:
-            self.transport.write(''.join(traceback.format_exception(
-                type(exc), exc, exc.__traceback__)).encode('utf-8'))
-        self.transport.write(b'\0')
-
-        self.transport.write(str(exc).encode('utf-8') + b'\0')
-
-
-def sighandler(loop, signame, *servers):
+def sighandler(loop, signame, servers):
     print('caught {}, exiting'.format(signame))
     for server in servers:
         server.close()
@@ -183,7 +25,6 @@ parser.add_argument('--debug', action='store_true', default=False,
     help='Enable verbose error logging (all exceptions with full '
          'tracebacks) and also send tracebacks to Admin API clients')
 
-
 def main(args=None):
     loop = asyncio.get_event_loop()
     libvirtaio.virEventRegisterAsyncIOImpl(loop=loop)
@@ -195,42 +36,23 @@ def main(args=None):
 
     args.app.vmm.register_event_handlers(args.app)
 
-    try:
-        os.unlink(QUBESD_SOCK)
-    except FileNotFoundError:
-        pass
-    old_umask = os.umask(0o007)
-    server = loop.run_until_complete(loop.create_unix_server(
-        functools.partial(QubesDaemonProtocol, qubes.api.admin.QubesAdminAPI,
-            app=args.app, debug=args.debug), QUBESD_SOCK))
-    shutil.chown(QUBESD_SOCK, group='qubes')
-
-    try:
-        os.unlink(QUBESD_INTERNAL_SOCK)
-    except FileNotFoundError:
-        pass
-    server_internal = loop.run_until_complete(loop.create_unix_server(
-        functools.partial(QubesDaemonProtocol,
-            qubes.api.internal.QubesInternalAPI,
-            app=args.app, debug=args.debug), QUBESD_INTERNAL_SOCK))
-    shutil.chown(QUBESD_INTERNAL_SOCK, group='qubes')
-
-    try:
-        os.unlink(QUBESD_MISC_SOCK)
-    except FileNotFoundError:
-        pass
-    server_misc = loop.run_until_complete(loop.create_unix_server(
-        functools.partial(QubesDaemonProtocol,
-            qubes.api.misc.QubesMiscAPI,
-            app=args.app, debug=args.debug), QUBESD_MISC_SOCK))
-    shutil.chown(QUBESD_MISC_SOCK, group='qubes')
-
-    os.umask(old_umask)
-    del old_umask
+    servers = []
+    servers.append(loop.run_until_complete(qubes.api.create_server(
+        qubes.api.admin.QUBESD_ADMIN_SOCK,
+        qubes.api.admin.QubesAdminAPI,
+        app=args.app, debug=args.debug)))
+    servers.append(loop.run_until_complete(qubes.api.create_server(
+        qubes.api.internal.QUBESD_INTERNAL_SOCK,
+        qubes.api.internal.QubesInternalAPI,
+        app=args.app, debug=args.debug)))
+    servers.append(loop.run_until_complete(qubes.api.create_server(
+        qubes.api.misc.QUBESD_MISC_SOCK,
+        qubes.api.misc.QubesMiscAPI,
+        app=args.app, debug=args.debug)))
 
     for signame in ('SIGINT', 'SIGTERM'):
         loop.add_signal_handler(getattr(signal, signame),
-            sighandler, loop, signame, server, server_internal, server_misc)
+            sighandler, loop, signame, servers)
 
     qubes.utils.systemd_notify()
     # make sure children will not inherit this
@@ -239,9 +61,7 @@ def main(args=None):
     try:
         loop.run_forever()
         loop.run_until_complete(asyncio.wait([
-            server.wait_closed(),
-            server_internal.wait_closed(),
-        ]))
+            server.wait_closed() for server in servers]))
     finally:
         loop.close()