Quellcode durchsuchen

get_connected_ips: handle empty and missing keys, add tests

Pawel Marczewski vor 4 Jahren
Ursprung
Commit
4aace50313
2 geänderte Dateien mit 58 neuen und 13 gelöschten Zeilen
  1. 18 12
      qubesagent/firewall.py
  2. 40 1
      qubesagent/test_firewall.py

+ 18 - 12
qubesagent/firewall.py

@@ -82,6 +82,8 @@ class FirewallWorker(object):
 
     def get_connected_ips(self, family):
         ips = self.qdb.read('/connected-ips6' if family == 6 else '/connected-ips')
+        if ips is None:
+            return []
         return ips.decode().split()
 
     def run_firewall_dir(self):
@@ -511,20 +513,24 @@ class NftablesWorker(FirewallWorker):
 
     def update_connected_ips(self, family):
         family_name = ('ip6' if family == 6 else 'ip')
-        ips = self.get_connected_ips(family)
-        if ips:
-            addr = '{' + ', '.join(ips) + '}'
-            irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format(
-                family_name=family_name, addr=addr)
-            orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format(
-                family_name=family_name, addr=addr)
-        else:
-            irule = ''
-            orule = ''
+        table = 'qubes-firewall'
 
-        nft_input = (
+        self.run_nft((
             'flush chain {family_name} {table} prerouting\n'
             'flush chain {family_name} {table} postrouting\n'
+        ).format(family_name=family_name, table=table))
+
+        ips = self.get_connected_ips(family)
+        if not ips:
+            return
+
+        addr = '{' + ', '.join(ips) + '}'
+        irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format(
+            family_name=family_name, addr=addr)
+        orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format(
+            family_name=family_name, addr=addr)
+
+        nft_input = (
             'table {family_name} {table} {{\n'
             '  chain prerouting {{\n'
             '    {irule}'
@@ -535,7 +541,7 @@ class NftablesWorker(FirewallWorker):
             '}}\n'
         ).format(
             family_name=family_name,
-            table='qubes-firewall',
+            table=table,
             irule=irule,
             orule=orule,
         )

+ 40 - 1
qubesagent/test_firewall.py

@@ -349,6 +349,25 @@ class TestIptablesWorker(TestCase):
              '!', '-o', 'vif+', '-d', '10.137.0.2', '-j', 'DROP']
         ])
 
+    def test_009_update_connected_ips_empty(self):
+        self.obj.qdb.entries['/connected-ips'] = b''
+        self.obj.called_commands[4] = []
+        self.obj.update_connected_ips(4)
+
+        self.assertEqual(self.obj.called_commands[4], [
+            ['-t', 'raw', '-F', 'QBS-PREROUTING'],
+            ['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
+        ])
+
+    def test_010_update_connected_ips_missing(self):
+        self.obj.called_commands[4] = []
+        self.obj.update_connected_ips(4)
+
+        self.assertEqual(self.obj.called_commands[4], [
+            ['-t', 'raw', '-F', 'QBS-PREROUTING'],
+            ['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
+        ])
+
 
 class TestNftablesWorker(TestCase):
     def setUp(self):
@@ -534,7 +553,8 @@ class TestNftablesWorker(TestCase):
 
         self.assertEqual(self.obj.loaded_rules, [
             'flush chain ip qubes-firewall prerouting\n'
-            'flush chain ip qubes-firewall postrouting\n'
+            'flush chain ip qubes-firewall postrouting\n',
+
             'table ip qubes-firewall {\n'
             '  chain prerouting {\n'
             '    iifname != "vif*" ip saddr {10.137.0.1, 10.137.0.2} drop\n'
@@ -545,6 +565,25 @@ class TestNftablesWorker(TestCase):
             '}\n'
         ])
 
+    def test_009_update_connected_ips_empty(self):
+        self.obj.qdb.entries['/connected-ips'] = b''
+        self.obj.loaded_rules = []
+        self.obj.update_connected_ips(4)
+
+        self.assertEqual(self.obj.loaded_rules, [
+            'flush chain ip qubes-firewall prerouting\n'
+            'flush chain ip qubes-firewall postrouting\n'
+        ])
+
+    def test_010_update_connected_ips_missing(self):
+        self.obj.loaded_rules = []
+        self.obj.update_connected_ips(4)
+
+        self.assertEqual(self.obj.loaded_rules, [
+            'flush chain ip qubes-firewall prerouting\n'
+            'flush chain ip qubes-firewall postrouting\n'
+        ])
+
 class TestFirewallWorker(TestCase):
     def setUp(self):
         self.obj = FirewallWorker()