Browse Source

qubes: new devices API

Allow device plugin to list attached and available devices. Enforce
at API level every device being exposed by some domain.

This commit only changes devices API, but not update existing users
(pci) yet.

QubesOS/qubes-issues#2257
Marek Marczykowski-Górecki 7 years ago
parent
commit
d7a3c0d319

+ 1 - 1
doc/example.xml

@@ -22,7 +22,7 @@
             </features>
 
             <devices class="pci">
-                <device>01:23.45</device>
+                <device backend-domain="dom0" id="01:23.45"/>
             </devices>
         </domain>
 

+ 181 - 44
qubes/devices.py

@@ -23,8 +23,30 @@
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 #
 
-import re
-
+'''API for various types of devices.
+
+Main concept is that some domain main
+expose (potentially multiple) devices, which can be attached to other domains.
+Devices can be of different classes (like 'pci', 'usb', etc). Each device
+class is implemented by an extension.
+
+Devices are identified by pair of (backend domain, `ident`), where `ident` is
+:py:class:`str`.
+
+Such extension should provide:
+ - `qubes.devices` endpoint - a class descendant from
+ :py:class:`qubes.devices.DeviceInfo`, designed to hold device description (
+ including class-specific properties)
+ - handle `device-attach:class` and `device-detach:class` events for
+ performing the attach/detach action; events are fired even when domain isn't
+ running and extension should be prepared for this
+ - handle `device-list:class` event - list devices exposed by particular
+ domain; it should return list of appropriate DeviceInfo objects
+ - handle `device-get:class` event - get one device object exposed by this
+ domain of given identifier
+ - handle `device-list-attached:class` event - list currently attached
+ devices to this domain
+'''
 import qubes.utils
 
 
@@ -35,6 +57,52 @@ class DeviceCollection(object):
 
     :param vm: VM for which we manage devices
     :param class_: device class
+
+    This class emits following events on VM object:
+
+        .. event:: device-attach:<class> (device)
+
+            Fired when device is attached to a VM.
+
+            :param device: :py:class:`DeviceInfo` object to be attached
+
+        .. event:: device-pre-attach:<class> (device)
+
+            Fired before device is attached to a VM
+
+            :param device: :py:class:`DeviceInfo` object to be attached
+
+        .. event:: device-detach:<class> (device)
+
+            Fired when device is detached from a VM.
+
+            :param device: :py:class:`DeviceInfo` object to be attached
+
+        .. event:: device-pre-detach:<class> (device)
+
+            Fired before device is detached from a VM
+
+            :param device: :py:class:`DeviceInfo` object to be attached
+
+        .. event:: device-list:<class>
+
+            Fired to get list of devices exposed by a VM. Handlers of this
+            event should return a list of py:class:`DeviceInfo` objects (or
+            appropriate class specific descendant)
+
+        .. event:: device-get:<class> (ident)
+
+            Fired to get a single device, given by the `ident` parameter.
+            Handlers of this event should either return appropriate object of
+            :py:class:`DeviceInfo`, or :py:obj:`None`. Especially should not
+            raise :py:class:`exceptions.KeyError`.
+
+        .. event:: device-list-attached:<class> (persistent)
+
+            Fired to get list of currently attached devices to a VM. Handlers
+            of this event should return list of devices actually attached to
+            a domain, regardless of its settings.
+
     '''
 
     def __init__(self, vm, class_):
@@ -42,23 +110,16 @@ class DeviceCollection(object):
         self._class = class_
         self._set = set()
 
+        self.devclass = qubes.utils.get_entry_point_one(
+            'qubes.devices', self._class)
 
     def attach(self, device):
         '''Attach (add) device to domain.
 
-        :param str device: device identifier (format is class-dependent)
+        :param DeviceInfo device: device object
         '''
 
-        try:
-            devclass = qubes.utils.get_entry_point_one(
-                'qubes.devices', self._class)
-        except KeyError:
-            devclass = str
-
-        if not isinstance(device, devclass):
-            device = devclass(device)
-
-        if device in self:
+        if device in self.attached():
             raise KeyError(
                 'device {!r} of class {} already attached to {!r}'.format(
                     device, self._class, self._vm))
@@ -70,10 +131,10 @@ class DeviceCollection(object):
     def detach(self, device):
         '''Detach (remove) device from domain.
 
-        :param str device: device identifier (format is class-dependent)
+        :param DeviceInfo device: device object
         '''
 
-        if device not in self:
+        if device not in self.attached():
             raise KeyError(
                 'device {!r} of class {} not attached to {!r}'.format(
                     device, self._class, self._vm))
@@ -81,17 +142,73 @@ class DeviceCollection(object):
         self._set.remove(device)
         self._vm.fire_event('device-detach:' + self._class, device)
 
+    def attached(self, persistent=None):
+        '''List devices which are (or may be) attached to this vm
+
+        Devices may be attached persistently (so they are included in
+        :file:`qubes.xml`) or not. Device can also be in :file:`qubes.xml`,
+        but be temporarily detached.
+
+        :param bool persistent: only include devices which are (or are not) \
+        attached persistently - None means both
+        '''
+        seen = self._set.copy()
+
+        # ask for really attached devices only when requested not only
+        # persistent ones
+        if persistent is not True:
+            attached = self._vm.fire_event(
+                'device-list-attached:' + self._class,
+                persistent=persistent)
+            for device in attached:
+                device_persistent = device in self._set
+                if persistent is not None and device_persistent != persistent:
+                    continue
+                assert device.frontend_domain == self._vm, \
+                    '{!r} != {!r}'.format(device.frontend_domain, self._vm)
+
+                yield device
+
+                try:
+                    seen.remove(device)
+                except KeyError:
+                    pass
+
+        if persistent is False:
+            return
+
+        for device in seen:
+            # get fresh object - may contain updated information
+            device = device.backend_domain.devices[self._class][device.ident]
+            yield device
+
+    def available(self):
+        '''List devices exposed by this vm'''
+        devices = self._vm.fire_event('device-list:' + self._class)
+        return devices
 
     def __iter__(self):
-        return iter(self._set)
+        return iter(self.available())
 
+    def __getitem__(self, ident):
+        '''Get device object with given ident.
 
-    def __contains__(self, item):
-        return item in self._set
+        :returns: py:class:`DeviceInfo`
 
+        If domain isn't running, it is impossible to check device validity,
+        so return UnknownDevice object. Also do the same for non-existing
+        devices - otherwise it will be impossible to detach already
+        disconnected device.
 
-    def __len__(self):
-        return len(self._set)
+        :raises AssertionError: when multiple devices with the same ident are
+        found
+        '''
+        dev = self._vm.fire_event('device-get:' + self._class, ident)
+        if dev:
+            assert len(dev) == 1
+            return dev[0]
+        else:
+            return UnknownDevice(self._vm, ident)
 
 
 class DeviceManager(dict):
@@ -109,30 +226,50 @@ class DeviceManager(dict):
         return self[key]
 
 
-class RegexDevice(str):
-    regex = None
-    def __init__(self, *args, **kwargs):
-        super(RegexDevice, self).__init__(*args, **kwargs)
-
-        if self.regex is None:
-            raise NotImplementedError(
-                'You should overload .regex attribute in subclass')
-
-        dev_match = self.regex.match(self)
-        if not dev_match:
-            raise ValueError('Invalid device identifier: {!r}'.format(self))
-
-        for group in self.regex.groupindex:
-            setattr(self, group, dev_match.group(group))
-
-
-class PCIDevice(RegexDevice):
-    regex = re.compile(
-        r'^(?P<bus>[0-9a-f]+):(?P<device>[0-9a-f]+)\.(?P<function>[0-9a-f]+)$')
-
-    @property
-    def libvirt_name(self):
-        return 'pci_0000_{}_{}_{}'.format(self.bus, self.device, self.function)
+class DeviceInfo(object):
+    # pylint: disable=too-few-public-methods
+    def __init__(self, backend_domain, ident, description=None,
+            frontend_domain=None, **kwargs):
+        #: domain providing this device
+        self.backend_domain = backend_domain
+        #: device identifier (unique for given domain and device type)
+        self.ident = ident
+        #: human readable description/name of the device
+        self.description = description
+        #: (running) domain to which device is currently attached
+        self.frontend_domain = frontend_domain
+        self.data = kwargs
+
+        if hasattr(self, 'regex'):
+            # pylint: disable=no-member
+            dev_match = self.regex.match(ident)
+            if not dev_match:
+                raise ValueError('Invalid device identifier: {!r}'.format(
+                    ident))
+
+            for group in self.regex.groupindex:
+                setattr(self, group, dev_match.group(group))
+
+    def __hash__(self):
+        return hash(self.ident)
+
+    def __eq__(self, other):
+        return (
+            self.backend_domain == other.backend_domain and
+            self.ident == other.ident
+        )
+
+
+class UnknownDevice(DeviceInfo):
+    # pylint: disable=too-few-public-methods
+    '''Unknown device - for example exposed by domain not running currently'''
+
+    def __init__(self, backend_domain, ident, description=None,
+            frontend_domain=None, **kwargs):
+        if description is None:
+            description = "Unknown device"
+        super(UnknownDevice, self).__init__(backend_domain, ident, description,
+            frontend_domain, **kwargs)
 
 
 class BlockDevice(object):

+ 1 - 1
qubes/events.py

@@ -129,7 +129,7 @@ class Emitter(object):
         '''
 
         if not self.events_enabled:
-            return
+            return []
 
         effects = []
         for cls in order:

+ 5 - 2
qubes/tests/__init__.py

@@ -135,12 +135,15 @@ class TestEmitter(qubes.events.Emitter):
         self.fired_events = collections.Counter()
 
     def fire_event(self, event, *args, **kwargs):
-        super(TestEmitter, self).fire_event(event, *args, **kwargs)
+        effects = super(TestEmitter, self).fire_event(event, *args, **kwargs)
         self.fired_events[(event, args, tuple(sorted(kwargs.items())))] += 1
+        return effects
 
     def fire_event_pre(self, event, *args, **kwargs):
-        super(TestEmitter, self).fire_event_pre(event, *args, **kwargs)
+        effects = super(TestEmitter, self).fire_event_pre(event, *args,
+            **kwargs)
         self.fired_events[(event, args, tuple(sorted(kwargs.items())))] += 1
+        return effects
 
 def expectedFailureIfTemplate(templates):
     """

+ 59 - 13
qubes/tests/devices.py

@@ -27,24 +27,68 @@ import qubes.devices
 
 import qubes.tests
 
+class TestDevice(qubes.devices.DeviceInfo):
+    pass
+
+
+class TestVMCollection(dict):
+    def __iter__(self):
+        return iter(set(self.values()))
+
+
+class TestApp(object):
+    def __init__(self):
+        self.domains = TestVMCollection()
+
+
+class TestVM(qubes.tests.TestEmitter):
+    def __init__(self, app, name, *args, **kwargs):
+        super(TestVM, self).__init__(*args, **kwargs)
+        self.app = app
+        self.name = name
+        self.device = TestDevice(self, 'testdev', 'Description')
+        self.events_enabled = True
+        self.devices = {
+            'testclass': qubes.devices.DeviceCollection(self, 'testclass')
+        }
+        self.app.domains[name] = self
+        self.app.domains[self] = self
+
+    def __str__(self):
+        return self.name
+
+    @qubes.events.handler('device-list-attached:testclass')
+    def dev_testclass_list_attached(self, event, persistent):
+        for vm in self.app.domains:
+            if vm.device.frontend_domain == self:
+                yield vm.device
+
+    @qubes.events.handler('device-list:testclass')
+    def dev_testclass_list(self, event):
+        yield self.device
+
+
 class TC_00_DeviceCollection(qubes.tests.QubesTestCase):
     def setUp(self):
-        self.emitter = qubes.tests.TestEmitter()
-        self.collection = qubes.devices.DeviceCollection(self.emitter, 'testclass')
+        self.app = TestApp()
+        self.emitter = TestVM(self.app, 'vm')
+        self.app.domains['vm'] = self.emitter
+        self.device = self.emitter.device
+        self.collection = self.emitter.devices['testclass']
 
     def test_000_init(self):
         self.assertFalse(self.collection._set)
 
     def test_001_attach(self):
-        self.collection.attach('testdev')
+        self.collection.attach(self.device)
         self.assertEventFired(self.emitter, 'device-pre-attach:testclass')
         self.assertEventFired(self.emitter, 'device-attach:testclass')
         self.assertEventNotFired(self.emitter, 'device-pre-detach:testclass')
         self.assertEventNotFired(self.emitter, 'device-detach:testclass')
 
     def test_002_detach(self):
-        self.collection.attach('testdev')
-        self.collection.detach('testdev')
+        self.collection.attach(self.device)
+        self.collection.detach(self.device)
         self.assertEventFired(self.emitter, 'device-pre-attach:testclass')
         self.assertEventFired(self.emitter, 'device-attach:testclass')
         self.assertEventFired(self.emitter, 'device-pre-detach:testclass')
@@ -52,31 +96,33 @@ class TC_00_DeviceCollection(qubes.tests.QubesTestCase):
 
     def test_010_empty_detach(self):
         with self.assertRaises(LookupError):
-            self.collection.detach('testdev')
+            self.collection.detach(self.device)
 
     def test_011_double_attach(self):
-        self.collection.attach('testdev')
+        self.collection.attach(self.device)
 
         with self.assertRaises(LookupError):
-            self.collection.attach('testdev')
+            self.collection.attach(self.device)
 
     def test_012_double_detach(self):
-        self.collection.attach('testdev')
-        self.collection.detach('testdev')
+        self.collection.attach(self.device)
+        self.collection.detach(self.device)
 
         with self.assertRaises(LookupError):
-            self.collection.detach('testdev')
+            self.collection.detach(self.device)
 
 
 class TC_01_DeviceManager(qubes.tests.QubesTestCase):
     def setUp(self):
-        self.emitter = qubes.tests.TestEmitter()
+        self.app = TestApp()
+        self.emitter = TestVM(self.app, 'vm')
         self.manager = qubes.devices.DeviceManager(self.emitter)
 
     def test_000_init(self):
         self.assertEqual(self.manager, {})
 
     def test_001_missing(self):
-        self.manager['testclass'].attach('testdev')
+        device = TestDevice(self.emitter.app.domains['vm'], 'testdev')
+        self.manager['testclass'].attach(device)
         self.assertEventFired(self.emitter, 'device-attach:testclass')
 

+ 10 - 3
qubes/tests/vm/init.py

@@ -31,6 +31,11 @@ import qubes.vm
 
 import qubes.tests
 
+class TestApp(object):
+    def __init__(self):
+        super(TestApp, self).__init__()
+        self.domains = {}
+
 
 class TestVM(qubes.vm.BaseVM):
     qid = qubes.property('qid', type=int)
@@ -66,7 +71,7 @@ class TC_10_BaseVM(qubes.tests.QubesTestCase):
             </features>
 
             <devices class="pci">
-                <device>00:11.22</device>
+                <device backend-domain="domain1" id="00:11.22"/>
             </devices>
 
             <devices class="usb" />
@@ -81,7 +86,8 @@ class TC_10_BaseVM(qubes.tests.QubesTestCase):
 
     def test_000_load(self):
         node = self.xml.xpath('//domain')[0]
-        vm = TestVM(None, node)
+        vm = TestVM(TestApp(), node)
+        vm.app.domains['domain1'] = vm
         vm.load_properties(load_stage=None)
         vm.load_extras()
 
@@ -97,7 +103,8 @@ class TC_10_BaseVM(qubes.tests.QubesTestCase):
         })
 
         self.assertItemsEqual(vm.devices.keys(), ('pci',))
-        self.assertItemsEqual(vm.devices['pci'], ('00:11.22',))
+        self.assertItemsEqual(list(vm.devices['pci'].attached(persistent=True)),
+            [qubes.devices.PCIDevice(vm, '00:11.22')])
 
         self.assertXMLIsValid(vm.__xml__(), 'domain.rng')
 

+ 8 - 3
qubes/vm/__init__.py

@@ -198,7 +198,11 @@ class BaseVM(qubes.PropertyHolder):
         for parent in self.xml.xpath('./devices'):
             devclass = parent.get('class')
             for node in parent.xpath('./device'):
-                self.devices[devclass].attach(node.text)
+                device = self.devices[devclass].devclass(
+                    self.app.domains[node.get('backend-domain')],
+                    node.get('id')
+                )
+                self.devices[devclass].attach(device)
 
         # tags
         for node in self.xml.xpath('./tags/tag'):
@@ -227,9 +231,10 @@ class BaseVM(qubes.PropertyHolder):
         for devclass in self.devices:
             devices = lxml.etree.Element('devices')
             devices.set('class', devclass)
-            for device in self.devices[devclass]:
+            for device in self.devices[devclass].attached(persistent=True):
                 node = lxml.etree.Element('device')
-                node.text = device
+                node.set('backend-domain', device.backend_domain.name)
+                node.set('id', device.ident)
                 devices.append(node)
             element.append(devices)
 

+ 6 - 3
qubes/vm/qubesvm.py

@@ -243,7 +243,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
     # CORE2: swallowed uses_default_kernelopts
     kernelopts = qubes.property('kernelopts', type=str, load_stage=4,
         default=(lambda self: qubes.config.defaults['kernelopts_pcidevs']
-            if len(self.devices['pci']) > 0
+            if list(self.devices['pci'].attached())
             else self.template.kernelopts if hasattr(self, 'template')
             else qubes.config.defaults['kernelopts']),
         ls_width=30,
@@ -476,6 +476,9 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
             self.events_enabled = True
         self.fire_event('domain-init')
 
+    def __hash__(self):
+        return self.qid
+
     def __xml__(self):
         element = super(QubesVM, self).__xml__()
 
@@ -700,7 +703,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
         qmemman_client = self.request_memory(mem_required)
 
         # Bind pci devices to pciback driver
-        for pci in self.devices['pci']:
+        for pci in self.devices['pci'].attached():
             self.bind_pci_to_pciback(pci)
 
         self.libvirt_domain.createWithFlags(libvirt.VIR_DOMAIN_START_PAUSED)
@@ -802,7 +805,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
         if not self.is_running() and not self.is_paused():
             raise qubes.exc.QubesVMNotRunningError(self)
 
-        if len(self.devices['pci']) > 0:
+        if list(self.devices['pci'].attached()):
             raise qubes.exc.QubesNotImplementedError(
                 'Cannot suspend domain {!r} which has PCI devices attached'
                 .format(self.name))

+ 17 - 7
relaxng/qubes.rng

@@ -206,14 +206,24 @@ the parser will complain about missing combine= attribute on the second <start>.
                     <oneOrMore>
                         <element name="device">
                             <doc:description>
-                                One device. This tag should contain some
-                                identifier, format of which depends on
-                                particular device class.
+                                One device. It's identified by by a pair of
+                                backend domain and some identifier (device class
+                                dependant).
                             </doc:description>
-                            <!-- TODO: pattern dependent on class! -->
-                            <data type="string">
-                                <param name="pattern">[0-9a-f]{2}:[0-9a-f]{2}.[0-9a-f]{2}</param>
-                            </data>
+                            <attribute name="backend-domain">
+                                <doc:description>
+                                    Backend domain name.
+                                </doc:description>
+                                <data type="string">
+                                    <param name="pattern">[a-z0-9_]+</param>
+                                </data>
+                            </attribute>
+                            <attribute name="id">
+                                <!-- TODO: pattern dependent on class! -->
+                                <data type="string">
+                                    <param name="pattern">[0-9a-f]{2}:[0-9a-f]{2}.[0-9a-f]{2}</param>
+                                </data>
+                            </attribute>
                         </element>
                     </oneOrMore>
                 </element>

+ 1 - 0
setup.py

@@ -43,6 +43,7 @@ if __name__ == '__main__':
             ],
             'qubes.devices': [
                 'pci = qubes.devices:PCIDevice',
+                'testclass = qubes.tests.devices:TestDevice',
             ],
             'qubes.storage': [
                 'file = qubes.storage.file:FilePool',

+ 4 - 3
templates/libvirt/xen.xml

@@ -27,7 +27,8 @@
             <viridian/>
         {% endif %}
 
-        {% if vm.devices['pci'] and vm.features.get('pci-e820-host', True) %}
+        {% if vm.devices['pci'].attached() | list
+                and vm.features.get('pci-e820-host', True) %}
             <xen>
                 <e820_host state="on"/>
             </xen>
@@ -90,10 +91,10 @@
         {% endfor %}
 
         {% if vm.netvm %}
-            {% include 'libvirt/devices/net.xml' %}
+            {% include 'libvirt/devices/net.xml' with context %}
         {% endif %}
 
-        {% for device in vm.devices.pci %}
+        {% for device in vm.devices.pci.attached() %}
             {% include 'libvirt/devices/pci.xml' %}
         {% endfor %}