Browse Source

qubes/vm: DeviceManager class for herding devices

collections.defaultdict was not enough, because it cannot pass any arguments to
factory. We need to pass domain object and device class to fire events on attach
and detach.
Wojtek Porczyk 9 years ago
parent
commit
ef4f00dac0
2 changed files with 152 additions and 2 deletions
  1. 68 1
      qubes/vm/__init__.py
  2. 84 1
      tests/vm.py

+ 68 - 1
qubes/vm/__init__.py

@@ -68,6 +68,73 @@ class BaseVMMeta(qubes.plugins.Plugin, qubes.events.EmitterMeta):
         cls.__hooks__ = collections.defaultdict(list)
 
 
+class DeviceCollection(object):
+    '''Bag for devices.
+
+    Used as default value for :py:meth:`DeviceManager.__missing__` factory.
+
+    :param vm: VM for which we manage devices
+    :param class_: device class
+    '''
+
+    def __init__(self, vm, class_):
+        self._vm = vm
+        self._class = class_
+        self._set = set()
+
+
+    def attach(self, device):
+        '''Attach (add) device to domain.
+
+        :param str device: device identifier (format is class-dependent)
+        '''
+
+        if device in self:
+            raise KeyError(
+                'device {!r} of class {} already attached to {!r}'.format(
+                    device, self._class, self._vm))
+        self._vm.fire_event('device-pre-attached:{}'.format(self._class), device)
+        self._set.add(device)
+        self._vm.fire_event('device-attached:{}'.format(self._class), device)
+
+
+    def detach(self, device):
+        '''Detach (remove) device from domain.
+
+        :param str device: device identifier (format is class-dependent)
+        '''
+
+        if device not in self:
+            raise KeyError(
+                'device {!r} of class {} not attached to {!r}'.format(
+                    device, self._class, self._vm))
+        self._vm.fire_event('device-pre-detached:{}'.format(self._class), device)
+        self._set.remove(device)
+        self._vm.fire_event('device-detached:{}'.format(self._class), device)
+
+
+    def __iter__(self):
+        return iter(self._set)
+
+
+    def __contains__(self, item):
+        return item in self._set
+
+
+class DeviceManager(dict):
+    '''Device manager that hold all devices by their classess.
+
+    :param vm: VM for which we manage devices
+    '''
+
+    def __init__(self, vm):
+        super(DeviceManager, self).__init__()
+        self._vm = vm
+
+    def __missing__(self, key):
+        return DeviceCollection(self._vm, key)
+
+
 class BaseVM(qubes.PropertyHolder):
     '''Base class for all VMs
 
@@ -87,7 +154,7 @@ class BaseVM(qubes.PropertyHolder):
             tags={}, *args, **kwargs):
         self.app = app
         self.services = services
-        self.devices = collections.defaultdict(list) if devices is None else devices
+        self.devices = DeviceManager(self) if devices is None else devices
         self.tags = tags
 
         self.events_enabled = False

+ 84 - 1
tests/vm.py

@@ -6,9 +6,92 @@ import unittest
 import lxml.etree
 
 sys.path.insert(0, '../')
+import qubes
+import qubes.events
 import qubes.vm
 
 
+class TestEmitter(qubes.events.Emitter):
+    def __init__(self):
+        super(TestEmitter, self).__init__()
+        self.device_pre_attached_fired = False
+        self.device_attached_fired = False
+        self.device_pre_detached_fired = False
+        self.device_detached_fired = False
+
+    @qubes.events.handler('device-pre-attached:testclass')
+    def on_device_pre_attached(self, event, dev):
+        self.device_pre_attached_fired = True
+
+    @qubes.events.handler('device-attached:testclass')
+    def on_device_attached(self, event, dev):
+        if self.device_pre_attached_fired:
+            self.device_attached_fired = True
+
+    @qubes.events.handler('device-pre-detached:testclass')
+    def on_device_pre_detached(self, event, dev):
+        if self.device_attached_fired:
+            self.device_pre_detached_fired = True
+
+    @qubes.events.handler('device-detached:testclass')
+    def on_device_detached(self, event, dev):
+        if self.device_pre_detached_fired:
+            self.device_detached_fired = True
+
+class TC_00_DeviceCollection(unittest.TestCase):
+    def setUp(self):
+        self.emitter = TestEmitter()
+        self.collection = qubes.vm.DeviceCollection(self.emitter, 'testclass')
+
+    def test_000_init(self):
+        self.assertFalse(self.collection._set)
+
+    def test_001_attach(self):
+        self.collection.attach('testdev')
+        self.assertTrue(self.emitter.device_pre_attached_fired)
+        self.assertTrue(self.emitter.device_attached_fired)
+        self.assertFalse(self.emitter.device_pre_detached_fired)
+        self.assertFalse(self.emitter.device_detached_fired)
+
+    def test_002_detach(self):
+        self.collection.attach('testdev')
+        self.collection.detach('testdev')
+        self.assertTrue(self.emitter.device_pre_attached_fired)
+        self.assertTrue(self.emitter.device_attached_fired)
+        self.assertTrue(self.emitter.device_pre_detached_fired)
+        self.assertTrue(self.emitter.device_detached_fired)
+
+    def test_010_empty_detach(self):
+        with self.assertRaises(LookupError):
+            self.collection.detach('testdev')
+
+    def test_011_double_attach(self):
+        self.collection.attach('testdev')
+
+        with self.assertRaises(LookupError):
+            self.collection.attach('testdev')
+
+    def test_012_double_detach(self):
+        self.collection.attach('testdev')
+        self.collection.detach('testdev')
+
+        with self.assertRaises(LookupError):
+            self.collection.detach('testdev')
+
+
+class TC_01_DeviceManager(unittest.TestCase):
+    def setUp(self):
+        self.emitter = TestEmitter()
+        self.manager = qubes.vm.DeviceManager(self.emitter)
+
+    def test_000_init(self):
+        self.assertEqual(self.manager, {})
+
+    def test_001_missing(self):
+        self.manager['testclass'].attach('testdev')
+        self.assertTrue(self.emitter.device_attached_fired)
+
+
 class TestVM(qubes.vm.BaseVM):
     qid = qubes.property('qid', type=int)
     name = qubes.property('name')
@@ -16,7 +99,7 @@ class TestVM(qubes.vm.BaseVM):
     testlabel = qubes.property('testlabel')
     defaultprop = qubes.property('defaultprop', default='defaultvalue')
 
-class TC_BaseVM(unittest.TestCase):
+class TC_10_BaseVM(unittest.TestCase):
     def setUp(self):
         self.xml = lxml.etree.XML('''
 <qubes version="3"> <!-- xmlns="https://qubes-os.org/QubesXML/1" -->