浏览代码

qubes: fix netvm properties and tests

fixes QubesOS/qubes-issues#1816
Wojtek Porczyk 8 年之前
父节点
当前提交
786884ad7a

+ 1 - 1
qubes/app.py

@@ -460,7 +460,7 @@ class VMCollection(object):
 
         while len(new_vms) > 0:
             cur_vm = new_vms.pop()
-            for vm in cur_vm.connected_vms.values():
+            for vm in cur_vm.connected_vms:
                 if vm in dependent_vms:
                     continue
                 dependent_vms.add(vm.qid)

+ 1 - 0
qubes/tests/__init__.py

@@ -950,6 +950,7 @@ def load_tests(loader, tests, pattern): # pylint: disable=unused-argument
             'qubes.tests.storage',
             'qubes.tests.storage_file',
             'qubes.tests.vm.qubesvm',
+            'qubes.tests.vm.mix.net',
             'qubes.tests.vm.adminvm',
             'qubes.tests.app',
             ):

+ 0 - 1
qubes/tests/app.py

@@ -30,7 +30,6 @@ import lxml.etree
 
 import qubes
 import qubes.events
-import qubes.vm
 
 import qubes.tests
 

+ 112 - 0
qubes/tests/init.py

@@ -283,3 +283,115 @@ class TC_20_PropertyHolder(qubes.tests.QubesTestCase):
     @unittest.skip('test not implemented')
     def test_010_property_require(self):
         pass
+
+
+class TestVM(qubes.vm.BaseVM):
+    qid = qubes.property('qid', type=int)
+    name = qubes.property('name')
+    netid = qid
+
+class TestApp(qubes.tests.TestEmitter):
+    pass
+
+class TC_30_VMCollection(qubes.tests.QubesTestCase):
+    def setUp(self):
+        self.app = TestApp()
+        self.vms = qubes.VMCollection(self.app)
+
+        self.testvm1 = TestVM(None, None, qid=1, name='testvm1')
+        self.testvm2 = TestVM(None, None, qid=2, name='testvm2')
+
+    def test_000_contains(self):
+        self.vms._dict = {1: self.testvm1}
+
+        self.assertIn(1, self.vms)
+        self.assertIn('testvm1', self.vms)
+        self.assertIn(self.testvm1, self.vms)
+
+        self.assertNotIn(2, self.vms)
+        self.assertNotIn('testvm2', self.vms)
+        self.assertNotIn(self.testvm2, self.vms)
+
+    def test_001_getitem(self):
+        self.vms._dict = {1: self.testvm1}
+
+        self.assertIs(self.vms[1], self.testvm1)
+        self.assertIs(self.vms['testvm1'], self.testvm1)
+        self.assertIs(self.vms[self.testvm1], self.testvm1)
+
+    def test_002_add(self):
+        self.vms.add(self.testvm1)
+        self.assertIn(1, self.vms)
+
+        self.assertEventFired(self.app, 'domain-add', args=[self.testvm1])
+
+        with self.assertRaises(TypeError):
+            self.vms.add(object())
+
+        testvm_qid_collision = TestVM(None, None, name='testvm2', qid=1)
+        testvm_name_collision = TestVM(None, None, name='testvm1', qid=2)
+
+        with self.assertRaises(ValueError):
+            self.vms.add(testvm_qid_collision)
+        with self.assertRaises(ValueError):
+            self.vms.add(testvm_name_collision)
+
+    def test_003_qids(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        self.assertItemsEqual(self.vms.qids(), [1, 2])
+        self.assertItemsEqual(self.vms.keys(), [1, 2])
+
+    def test_004_names(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        self.assertItemsEqual(self.vms.names(), ['testvm1', 'testvm2'])
+
+    def test_005_vms(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        self.assertItemsEqual(self.vms.vms(), [self.testvm1, self.testvm2])
+        self.assertItemsEqual(self.vms.values(), [self.testvm1, self.testvm2])
+
+    def test_006_items(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        self.assertItemsEqual(self.vms.items(),
+            [(1, self.testvm1), (2, self.testvm2)])
+
+    def test_007_len(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        self.assertEqual(len(self.vms), 2)
+
+    def test_008_delitem(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        del self.vms['testvm2']
+
+        self.assertItemsEqual(self.vms.vms(), [self.testvm1])
+        self.assertEventFired(self.app, 'domain-delete', args=[self.testvm2])
+
+    def test_100_get_new_unused_qid(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        self.vms.get_new_unused_qid()
+
+    def test_101_get_new_unused_netid(self):
+        self.vms.add(self.testvm1)
+        self.vms.add(self.testvm2)
+
+        self.vms.get_new_unused_netid()
+
+#   def test_200_get_vms_based_on(self):
+#       pass
+
+#   def test_201_get_vms_connected_to(self):
+#       pass

+ 4 - 0
qubes/tests/vm/__init__.py

@@ -29,6 +29,10 @@ class TestVMM(object):
     # pylint: disable=too-few-public-methods
     def __init__(self, offline_mode=False):
         self.offline_mode = offline_mode
+    @property
+    def libvirt_conn(self):
+        import libvirt
+        raise libvirt.libvirtError('phony error')
 
 class TestHost(object):
     # pylint: disable=too-few-public-methods

+ 0 - 0
qubes/tests/vm/mix/__init__.py


+ 124 - 0
qubes/tests/vm/mix/net.py

@@ -0,0 +1,124 @@
+#!/usr/bin/python2 -O
+# vim: fileencoding=utf-8
+# pylint: disable=protected-access
+
+#
+# The Qubes OS Project, https://www.qubes-os.org/
+#
+# Copyright (C) 2014-2016  Joanna Rutkowska <joanna@invisiblethingslab.com>
+# Copyright (C) 2014-2016  Wojtek Porczyk <woju@invisiblethingslab.com>
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License along
+# with this program; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+
+import qubes
+import qubes.vm.qubesvm
+
+import qubes.tests
+import qubes.tests.vm.qubesvm
+
+class TC_00_NetVMMixin(
+        qubes.tests.vm.qubesvm.QubesVMTestsMixin, qubes.tests.QubesTestCase):
+    def setUp(self):
+        super(TC_00_NetVMMixin, self).setUp()
+        self.app = qubes.tests.vm.TestApp()
+
+    def setup_netvms(self, vm):
+        # usage of QubesVM here means that those tests should be after
+        # testing properties used here
+        self.netvm1 = qubes.vm.qubesvm.QubesVM(self.app, None, qid=2,
+            name=qubes.tests.VMPREFIX + 'netvm1',
+            provides_network=True)
+        self.netvm2 = qubes.vm.qubesvm.QubesVM(self.app, None, qid=3,
+            name=qubes.tests.VMPREFIX + 'netvm2',
+            provides_network=True)
+        self.nonetvm = qubes.vm.qubesvm.QubesVM(self.app, None, qid=4,
+            name=qubes.tests.VMPREFIX + 'nonet')
+        self.app.domains = qubes.VMCollection(self.app)
+        for domain in (vm, self.netvm1, self.netvm2, self.nonetvm):
+            self.app.domains._dict[domain.qid] = domain
+        self.app.default_netvm = self.netvm1
+        self.app.default_fw_netvm = self.netvm1
+
+
+    def test_140_netvm(self):
+        vm = self.get_vm()
+        self.setup_netvms(vm)
+        self.assertPropertyDefaultValue(vm, 'netvm', self.app.default_netvm)
+        self.assertPropertyValue(vm, 'netvm', self.netvm2, self.netvm2,
+            self.netvm2.name)
+        del vm.netvm
+        self.assertPropertyDefaultValue(vm, 'netvm', self.app.default_netvm)
+        self.assertPropertyValue(vm, 'netvm', self.netvm2.name, self.netvm2,
+            self.netvm2.name)
+        self.assertPropertyValue(vm, 'netvm', None, None, '')
+
+    def test_141_netvm_invalid(self):
+        vm = self.get_vm()
+        self.setup_netvms(vm)
+        self.assertPropertyInvalidValue(vm, 'netvm', 'invalid')
+        self.assertPropertyInvalidValue(vm, 'netvm', 123)
+
+    def test_142_netvm_netvm(self):
+        vm = self.get_vm()
+        self.setup_netvms(vm)
+        self.assertPropertyInvalidValue(vm, 'netvm', self.nonetvm)
+
+    def test_143_netvm_loopback(self):
+        vm = self.get_vm()
+        self.app.domains = {1: vm, vm: vm}
+        self.assertPropertyInvalidValue(vm, 'netvm', vm)
+
+    def test_290_dispvm_netvm(self):
+        vm = self.get_vm()
+        self.setup_netvms(vm)
+        self.assertPropertyDefaultValue(vm, 'dispvm_netvm',
+            self.app.default_netvm)
+        self.assertPropertyValue(vm, 'dispvm_netvm', self.netvm2, self.netvm2,
+            self.netvm2.name)
+        del vm.dispvm_netvm
+        self.assertPropertyDefaultValue(vm, 'dispvm_netvm',
+            self.app.default_netvm)
+        self.assertPropertyValue(vm, 'dispvm_netvm', self.netvm2.name,
+            self.netvm2, self.netvm2.name)
+        # XXX FIXME xml value
+        self.assertPropertyValue(vm, 'dispvm_netvm', None, None, 'None')
+
+    def test_291_dispvm_netvm_invalid(self):
+        vm = self.get_vm()
+        self.setup_netvms(vm)
+        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', 'invalid')
+        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', 123)
+
+    def test_291_dispvm_netvm_netvm(self):
+        vm = self.get_vm()
+        nonetvm = TestVM(qid=2, app=self.app, name='nonetvm')
+        self.app.domains = {1: vm, 2: nonetvm}
+        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', nonetvm)
+
+    def test_291_dispvm_netvm_default(self):
+        """Check if vm.dispvm_netvm default is really vm.netvm"""
+        vm = self.get_vm()
+        self.setup_netvms(vm)
+        vm.netvm = self.netvm2
+        self.assertPropertyDefaultValue(vm, 'dispvm_netvm', self.netvm2)
+        del vm.netvm
+        self.assertPropertyDefaultValue(vm, 'dispvm_netvm', self.netvm1)
+
+    def test_292_dispvm_netvm_loopback(self):
+        vm = self.get_vm()
+        self.app.domains = {1: vm, vm: vm}
+        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', vm)
+

+ 5 - 90
qubes/tests/vm/qubesvm.py

@@ -140,9 +140,11 @@ class TC_00_setters(qubes.tests.QubesTestCase):
     # there is no check for self.app.get_label()
 
 
-class TC_90_QubesVM(qubes.tests.QubesTestCase):
+class QubesVMTestsMixin(object):
+    property_no_default = object()
+
     def setUp(self):
-        super(TC_90_QubesVM, self).setUp()
+        super(QubesVMTestsMixin, self).setUp()
         self.app = qubes.tests.vm.TestApp()
 
     def get_vm(self, **kwargs):
@@ -150,8 +152,6 @@ class TC_90_QubesVM(qubes.tests.QubesTestCase):
             qid=1, name=qubes.tests.VMPREFIX + 'test',
             **kwargs)
 
-    property_no_default = object()
-
     def assertPropertyValue(self, vm, prop_name, set_value, expected_value,
             expected_xml_content=None):
         # FIXME: any better exception list? or maybe all of that should be a
@@ -208,24 +208,8 @@ class TC_90_QubesVM(qubes.tests.QubesTestCase):
         self.assertPropertyValue(vm, prop_name, 123, True)
         self.assertPropertyInvalidValue(vm, prop_name, '')
 
-    def setup_netvms(self, vm):
-        # usage of QubesVM here means that those tests should be after
-        # testing properties used here
-        self.netvm1 = qubes.vm.qubesvm.QubesVM(self.app, None, qid=2,
-            name=qubes.tests.VMPREFIX + 'netvm1',
-            provides_network=True)
-        self.netvm2 = qubes.vm.qubesvm.QubesVM(self.app, None, qid=3,
-            name=qubes.tests.VMPREFIX + 'netvm2',
-            provides_network=True)
-        self.nonetvm = qubes.vm.qubesvm.QubesVM(self.app, None, qid=4,
-            name=qubes.tests.VMPREFIX + 'nonet')
-        self.app.domains = {}
-        for domain in (vm, self.netvm1, self.netvm2, self.nonetvm):
-            self.app.domains[domain.qid] = domain
-            self.app.domains[domain] = domain
-            self.app.domains[domain.name] = domain
-        self.app.default_netvm = self.netvm1
 
+class TC_90_QubesVM(QubesVMTestsMixin,qubes.tests.QubesTestCase):
     def test_000_init(self):
         self.get_vm()
 
@@ -287,34 +271,6 @@ class TC_90_QubesVM(qubes.tests.QubesTestCase):
         self.assertPropertyInvalidValue(vm, 'label', 'invalid')
         self.assertPropertyInvalidValue(vm, 'label', 123)
 
-    def test_140_netvm(self):
-        vm = self.get_vm()
-        self.setup_netvms(vm)
-        self.assertPropertyDefaultValue(vm, 'netvm', self.app.default_netvm)
-        self.assertPropertyValue(vm, 'netvm', self.netvm2, self.netvm2,
-            self.netvm2.name)
-        del vm.netvm
-        self.assertPropertyDefaultValue(vm, 'netvm', self.app.default_netvm)
-        self.assertPropertyValue(vm, 'netvm', self.netvm2.name, self.netvm2,
-            self.netvm2.name)
-        self.assertPropertyValue(vm, 'netvm', None, None, '')
-
-    def test_141_netvm_invalid(self):
-        vm = self.get_vm()
-        self.setup_netvms(vm)
-        self.assertPropertyInvalidValue(vm, 'netvm', 'invalid')
-        self.assertPropertyInvalidValue(vm, 'netvm', 123)
-
-    def test_142_netvm_netvm(self):
-        vm = self.get_vm()
-        self.setup_netvms(vm)
-        self.assertPropertyInvalidValue(vm, 'netvm', self.nonetvm)
-
-    def test_143_netvm_loopback(self):
-        vm = self.get_vm()
-        self.app.domains = {1: vm, vm: vm}
-        self.assertPropertyInvalidValue(vm, 'netvm', vm)
-
     def test_150_hvm(self):
         vm = self.get_vm()
         self._test_generic_bool_property(vm, 'hvm')
@@ -501,47 +457,6 @@ class TC_90_QubesVM(qubes.tests.QubesTestCase):
             'qubes-vm@{}.service'.format(vm.name)),
             "systemd service not disabled by resetting autostart")
 
-    def test_290_dispvm_netvm(self):
-        vm = self.get_vm()
-        self.setup_netvms(vm)
-        self.assertPropertyDefaultValue(vm, 'dispvm_netvm',
-            self.app.default_netvm)
-        self.assertPropertyValue(vm, 'dispvm_netvm', self.netvm2, self.netvm2,
-            self.netvm2.name)
-        del vm.dispvm_netvm
-        self.assertPropertyDefaultValue(vm, 'dispvm_netvm',
-            self.app.default_netvm)
-        self.assertPropertyValue(vm, 'dispvm_netvm', self.netvm2.name,
-            self.netvm2, self.netvm2.name)
-        # XXX FIXME xml value
-        self.assertPropertyValue(vm, 'dispvm_netvm', None, None, 'None')
-
-    def test_291_dispvm_netvm_invalid(self):
-        vm = self.get_vm()
-        self.setup_netvms(vm)
-        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', 'invalid')
-        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', 123)
-
-    def test_291_dispvm_netvm_netvm(self):
-        vm = self.get_vm()
-        nonetvm = TestVM(qid=2, app=self.app, name='nonetvm')
-        self.app.domains = {1: vm, 2: nonetvm}
-        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', nonetvm)
-
-    def test_291_dispvm_netvm_default(self):
-        """Check if vm.dispvm_netvm default is really vm.netvm"""
-        vm = self.get_vm()
-        self.setup_netvms(vm)
-        vm.netvm = self.netvm2
-        self.assertPropertyDefaultValue(vm, 'dispvm_netvm', self.netvm2)
-        del vm.netvm
-        self.assertPropertyDefaultValue(vm, 'dispvm_netvm', self.netvm1)
-
-    def test_292_dispvm_netvm_loopback(self):
-        vm = self.get_vm()
-        self.app.domains = {1: vm, vm: vm}
-        self.assertPropertyInvalidValue(vm, 'dispvm_netvm', vm)
-
     @unittest.skip('TODO')
     def test_300_qrexec_installed(self):
         vm = self.get_vm()

+ 21 - 6
qubes/vm/mix/net.py

@@ -28,9 +28,10 @@ import libvirt
 import lxml.etree
 
 import qubes
+import qubes.events
 import qubes.exc
 
-class NetVMMixin(object):
+class NetVMMixin(qubes.events.Emitter):
     mac = qubes.property('mac', type=str,
         default='00:16:3E:5E:6C:00',
         ls_width=17,
@@ -268,16 +269,30 @@ class NetVMMixin(object):
         self.fire_event('property-set:netvm', 'netvm', new_netvm, old_netvm)
 
 
-    @qubes.events.handler('property-set:netvm')
-    def on_property_set_netvm(self, event, name, new_netvm, old_netvm=None):
-        # pylint: disable=unused-argument
+    @qubes.events.handler('property-pre-set:netvm')
+    def on_property_pre_set_netvm(self, event, name, new_netvm, old_netvm=None):
+        if new_netvm is None:
+            return
+
+        if not new_netvm.provides_network:
+            raise qubes.exc.QubesValueError(
+                'The {!s} qube does not provide network'.format(new_netvm))
+
+        if new_netvm is self \
+                or new_netvm in self.app.domains.get_vms_connected_to(self):
+            raise qubes.exc.QubesValueError('Loops in network are unsupported')
+
         # TODO offline_mode
-        if self.is_running() and new_netvm is not None \
-                and not new_netvm.is_running():
+        if self.is_running() and not new_netvm.is_running():
             raise qubes.exc.QubesVMNotStartedError(new_netvm,
                 'Cannot dynamically attach to stopped NetVM: {!r}'.format(
                     new_netvm))
 
+
+    @qubes.events.handler('property-set:netvm')
+    def on_property_set_netvm(self, event, name, new_netvm, old_netvm=None):
+        # pylint: disable=unused-argument
+
         if self.netvm is not None:
             if self.is_running():
                 self.detach_network()

+ 4 - 0
rpm_spec/core-dom0.spec

@@ -276,6 +276,10 @@ fi
 %{python_sitelib}/qubes/tests/vm/adminvm.py*
 %{python_sitelib}/qubes/tests/vm/qubesvm.py*
 
+%dir %{python_sitelib}/qubes/tests/vm/mix
+%{python_sitelib}/qubes/tests/vm/mix/__init__.py*
+%{python_sitelib}/qubes/tests/vm/mix/net.py*
+
 %dir %{python_sitelib}/qubes/tests/tools
 %{python_sitelib}/qubes/tests/tools/__init__.py*
 %{python_sitelib}/qubes/tests/tools/init.py*