qubes/api: refactor creating multiple qubesd sockets

Now there is a single function to do this, shared with tests.
This commit is contained in:
Wojtek Porczyk 2017-06-06 15:49:19 +02:00 committed by Marek Marczykowski-Górecki
parent bec58fc861
commit 96a66ac6bd
No known key found for this signature in database
GPG Key ID: 063938BA42CFA724
7 changed files with 78 additions and 55 deletions

View File

@ -20,10 +20,12 @@
# with this program; if not, see <http://www.gnu.org/licenses/>. # with this program; if not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import errno
import functools import functools
import io import io
import os import os
import shutil import shutil
import socket
import struct import struct
import traceback import traceback
@ -105,6 +107,10 @@ class AbstractQubesAPI(object):
There are also two helper functions for firing events associated with API There are also two helper functions for firing events associated with API
calls. calls.
''' '''
#: the preferred socket location (to be overridden in child's class)
SOCKNAME = None
def __init__(self, app, src, method_name, dest, arg, send_event=None): def __init__(self, app, src, method_name, dest, arg, send_event=None):
#: :py:class:`qubes.Qubes` object #: :py:class:`qubes.Qubes` object
self.app = app self.app = app
@ -332,27 +338,61 @@ class QubesDaemonProtocol(asyncio.Protocol):
self.transport.write(str(exc).encode('utf-8') + b'\0') self.transport.write(str(exc).encode('utf-8') + b'\0')
_umask_lock = asyncio.Lock()
@asyncio.coroutine @asyncio.coroutine
def create_server(sockpath, handler, app, debug=False, *, loop=None): def create_servers(*args, force=False, loop=None, **kwargs):
loop = loop or asyncio.get_event_loop() '''Create multiple Qubes API servers
try:
os.unlink(sockpath)
except FileNotFoundError:
pass
with (yield from _umask_lock): :param qubes.Qubes app: the app that is a backend of the servers
:param bool force: if :py:obj:`True`, unconditionaly remove existing \
sockets; if :py:obj:`False`, raise an error if there is some process \
listening to such socket
:param asyncio.Loop loop: loop
*args* are supposed to be classess inheriting from
:py:class:`AbstractQubesAPI`
*kwargs* (like *app* or *debug* for example) are passed to
:py:class:`QubesDaemonProtocol` constructor
'''
loop = loop or asyncio.get_event_loop()
servers = []
old_umask = os.umask(0o007) old_umask = os.umask(0o007)
try: try:
# XXX this can be optimised with asyncio.wait() to start servers in
# parallel, but I currently don't see the need
for handler in args:
sockpath = handler.SOCKNAME
assert sockpath is not None, \
'SOCKNAME needs to be overloaded in {}'.format(
type(handler).__name__)
if os.path.exists(sockpath):
if force:
os.unlink(sockpath)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.connect(sockpath)
except ConnectionRefusedError:
# dead socket, remove it anyway
os.unlink(sockpath)
else:
# woops, someone is listening
sock.close()
raise FileExistsError(errno.EEXIST,
'socket already exists: {!r}'.format(sockpath))
server = yield from loop.create_unix_server( server = yield from loop.create_unix_server(
functools.partial(QubesDaemonProtocol, functools.partial(QubesDaemonProtocol, handler, **kwargs),
handler, app=app, debug=debug),
sockpath) sockpath)
finally:
os.umask(old_umask)
for sock in server.sockets: for sock in server.sockets:
shutil.chown(sock.getsockname(), group='qubes') shutil.chown(sock.getsockname(), group='qubes')
return server servers.append(server)
finally:
os.umask(old_umask)
return servers

View File

@ -35,8 +35,6 @@ import qubes.utils
import qubes.vm import qubes.vm
import qubes.vm.qubesvm import qubes.vm.qubesvm
QUBESD_ADMIN_SOCK = '/var/run/qubesd.sock'
class QubesMgmtEventsDispatcher(object): class QubesMgmtEventsDispatcher(object):
def __init__(self, filters, send_event): def __init__(self, filters, send_event):
@ -75,6 +73,8 @@ class QubesAdminAPI(qubes.api.AbstractQubesAPI):
https://www.qubes-os.org/doc/mgmt1/ https://www.qubes-os.org/doc/mgmt1/
''' '''
SOCKNAME = '/var/run/qubesd.sock'
@qubes.api.method('admin.vmclass.List', no_payload=True) @qubes.api.method('admin.vmclass.List', no_payload=True)
@asyncio.coroutine @asyncio.coroutine
def vmclass_list(self): def vmclass_list(self):

View File

@ -29,19 +29,12 @@ import qubes.api.admin
import qubes.vm.adminvm import qubes.vm.adminvm
import qubes.vm.dispvm import qubes.vm.dispvm
QUBESD_INTERNAL_SOCK = '/var/run/qubesd.internal.sock'
class QubesInternalAPI(qubes.api.AbstractQubesAPI): class QubesInternalAPI(qubes.api.AbstractQubesAPI):
''' Communication interface for dom0 components, ''' Communication interface for dom0 components,
by design the input here is trusted.''' by design the input here is trusted.'''
#
# PRIVATE METHODS, not to be called via RPC
#
# SOCKNAME = '/var/run/qubesd.internal.sock'
# ACTUAL RPC CALLS
#
@qubes.api.method('internal.GetSystemInfo', no_payload=True) @qubes.api.method('internal.GetSystemInfo', no_payload=True)
@asyncio.coroutine @asyncio.coroutine

View File

@ -28,10 +28,10 @@ import qubes.api
import qubes.api.admin import qubes.api.admin
import qubes.vm.dispvm import qubes.vm.dispvm
QUBESD_MISC_SOCK = '/var/run/qubesd.misc.sock'
class QubesMiscAPI(qubes.api.AbstractQubesAPI): class QubesMiscAPI(qubes.api.AbstractQubesAPI):
SOCKNAME = '/var/run/qubesd.misc.sock'
@qubes.api.method('qubes.FeaturesRequest', no_payload=True) @qubes.api.method('qubes.FeaturesRequest', no_payload=True)
@asyncio.coroutine @asyncio.coroutine
def qubes_features_request(self): def qubes_features_request(self):

View File

@ -103,7 +103,6 @@ except OSError:
# command not found; let's assume we're outside # command not found; let's assume we're outside
pass pass
def skipUnlessDom0(test_item): def skipUnlessDom0(test_item):
'''Decorator that skips test outside dom0. '''Decorator that skips test outside dom0.
@ -591,12 +590,11 @@ class SystemTestsMixin(object):
) )
os.environ['QUBES_XML_PATH'] = XMLPATH os.environ['QUBES_XML_PATH'] = XMLPATH
self.qrexec_policy_server = self.loop.run_until_complete( self.qubesd = self.loop.run_until_complete(
qubes.api.create_server( qubes.api.create_servers(
qubes.api.internal.QUBESD_INTERNAL_SOCK, qubes.api.admin.QubesAdminAPI,
qubes.api.internal.QubesInternalAPI, qubes.api.internal.QubesInternalAPI,
app=self.app, app=self.app, debug=True))
debug=True))
def init_default_template(self, template=None): def init_default_template(self, template=None):
if template is None: if template is None:
@ -680,11 +678,13 @@ class SystemTestsMixin(object):
self.reload_db() self.reload_db()
def tearDown(self): def tearDown(self):
# close the server before super(), because that might close the loop # close the servers before super(), because that might close the loop
for sock in self.qrexec_policy_server.sockets: for server in self.qubesd:
for sock in server.sockets:
os.unlink(sock.getsockname()) os.unlink(sock.getsockname())
self.qrexec_policy_server.close() server.close()
self.loop.run_until_complete(self.qrexec_policy_server.wait_closed()) self.loop.run_until_complete(asyncio.wait([
server.wait_closed() for server in self.qubesd]))
super(SystemTestsMixin, self).tearDown() super(SystemTestsMixin, self).tearDown()
self.remove_test_vms() self.remove_test_vms()

View File

@ -413,7 +413,7 @@ def main():
logging.root.addHandler(ha_kmsg) logging.root.addHandler(ha_kmsg)
if not args.allow_running_along_qubesd \ if not args.allow_running_along_qubesd \
and os.path.exists(qubes.api.admin.QUBESD_ADMIN_SOCK): and os.path.exists(qubes.api.admin.QubesAdminAPI.SOCKNAME):
parser.error('refusing to run until qubesd is disabled') parser.error('refusing to run until qubesd is disabled')
runner = unittest.TextTestRunner(stream=sys.stdout, runner = unittest.TextTestRunner(stream=sys.stdout,

View File

@ -36,19 +36,11 @@ def main(args=None):
args.app.vmm.register_event_handlers(args.app) args.app.vmm.register_event_handlers(args.app)
servers = [] servers = loop.run_until_complete(qubes.api.create_servers(
servers.append(loop.run_until_complete(qubes.api.create_server(
qubes.api.admin.QUBESD_ADMIN_SOCK,
qubes.api.admin.QubesAdminAPI, qubes.api.admin.QubesAdminAPI,
app=args.app, debug=args.debug)))
servers.append(loop.run_until_complete(qubes.api.create_server(
qubes.api.internal.QUBESD_INTERNAL_SOCK,
qubes.api.internal.QubesInternalAPI, qubes.api.internal.QubesInternalAPI,
app=args.app, debug=args.debug)))
servers.append(loop.run_until_complete(qubes.api.create_server(
qubes.api.misc.QUBESD_MISC_SOCK,
qubes.api.misc.QubesMiscAPI, qubes.api.misc.QubesMiscAPI,
app=args.app, debug=args.debug))) app=args.app, debug=args.debug))
socknames = [] socknames = []
for server in servers: for server in servers:
@ -71,11 +63,9 @@ def main(args=None):
try: try:
os.unlink(sockname) os.unlink(sockname)
except FileNotFoundError: except FileNotFoundError:
# XXX args.app.log.warning(
# We had our socket unlinked by somebody else, possibly other 'socket {} got unlinked sometime before shutdown'.format(
# qubesd instance. That also means we probably unlinked their sockname))
# socket when creating our server...
pass
finally: finally:
loop.close() loop.close()