diff --git a/qubesagent/firewall.py b/qubesagent/firewall.py index b8af899..0bf2ed1 100755 --- a/qubesagent/firewall.py +++ b/qubesagent/firewall.py @@ -82,6 +82,8 @@ class FirewallWorker(object): def get_connected_ips(self, family): ips = self.qdb.read('/connected-ips6' if family == 6 else '/connected-ips') + if ips is None: + return [] return ips.decode().split() def run_firewall_dir(self): @@ -511,20 +513,24 @@ class NftablesWorker(FirewallWorker): def update_connected_ips(self, family): family_name = ('ip6' if family == 6 else 'ip') - ips = self.get_connected_ips(family) - if ips: - addr = '{' + ', '.join(ips) + '}' - irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format( - family_name=family_name, addr=addr) - orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format( - family_name=family_name, addr=addr) - else: - irule = '' - orule = '' + table = 'qubes-firewall' - nft_input = ( + self.run_nft(( 'flush chain {family_name} {table} prerouting\n' 'flush chain {family_name} {table} postrouting\n' + ).format(family_name=family_name, table=table)) + + ips = self.get_connected_ips(family) + if not ips: + return + + addr = '{' + ', '.join(ips) + '}' + irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format( + family_name=family_name, addr=addr) + orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format( + family_name=family_name, addr=addr) + + nft_input = ( 'table {family_name} {table} {{\n' ' chain prerouting {{\n' ' {irule}' @@ -535,7 +541,7 @@ class NftablesWorker(FirewallWorker): '}}\n' ).format( family_name=family_name, - table='qubes-firewall', + table=table, irule=irule, orule=orule, ) diff --git a/qubesagent/test_firewall.py b/qubesagent/test_firewall.py index b3db626..7de73fc 100644 --- a/qubesagent/test_firewall.py +++ b/qubesagent/test_firewall.py @@ -349,6 +349,25 @@ class TestIptablesWorker(TestCase): '!', '-o', 'vif+', '-d', '10.137.0.2', '-j', 'DROP'] ]) + def test_009_update_connected_ips_empty(self): + self.obj.qdb.entries['/connected-ips'] = b'' + self.obj.called_commands[4] = [] + self.obj.update_connected_ips(4) + + self.assertEqual(self.obj.called_commands[4], [ + ['-t', 'raw', '-F', 'QBS-PREROUTING'], + ['-t', 'mangle', '-F', 'QBS-POSTROUTING'], + ]) + + def test_010_update_connected_ips_missing(self): + self.obj.called_commands[4] = [] + self.obj.update_connected_ips(4) + + self.assertEqual(self.obj.called_commands[4], [ + ['-t', 'raw', '-F', 'QBS-PREROUTING'], + ['-t', 'mangle', '-F', 'QBS-POSTROUTING'], + ]) + class TestNftablesWorker(TestCase): def setUp(self): @@ -534,7 +553,8 @@ class TestNftablesWorker(TestCase): self.assertEqual(self.obj.loaded_rules, [ 'flush chain ip qubes-firewall prerouting\n' - 'flush chain ip qubes-firewall postrouting\n' + 'flush chain ip qubes-firewall postrouting\n', + 'table ip qubes-firewall {\n' ' chain prerouting {\n' ' iifname != "vif*" ip saddr {10.137.0.1, 10.137.0.2} drop\n' @@ -545,6 +565,25 @@ class TestNftablesWorker(TestCase): '}\n' ]) + def test_009_update_connected_ips_empty(self): + self.obj.qdb.entries['/connected-ips'] = b'' + self.obj.loaded_rules = [] + self.obj.update_connected_ips(4) + + self.assertEqual(self.obj.loaded_rules, [ + 'flush chain ip qubes-firewall prerouting\n' + 'flush chain ip qubes-firewall postrouting\n' + ]) + + def test_010_update_connected_ips_missing(self): + self.obj.loaded_rules = [] + self.obj.update_connected_ips(4) + + self.assertEqual(self.obj.loaded_rules, [ + 'flush chain ip qubes-firewall prerouting\n' + 'flush chain ip qubes-firewall postrouting\n' + ]) + class TestFirewallWorker(TestCase): def setUp(self): self.obj = FirewallWorker()