get_connected_ips: handle empty and missing keys, add tests

This commit is contained in:
Pawel Marczewski 2020-01-14 10:22:38 +01:00
parent e43fd2fc5a
commit 4aace50313
No known key found for this signature in database
GPG Key ID: DE42EE9B14F96465
2 changed files with 58 additions and 13 deletions

View File

@ -82,6 +82,8 @@ class FirewallWorker(object):
def get_connected_ips(self, family): def get_connected_ips(self, family):
ips = self.qdb.read('/connected-ips6' if family == 6 else '/connected-ips') ips = self.qdb.read('/connected-ips6' if family == 6 else '/connected-ips')
if ips is None:
return []
return ips.decode().split() return ips.decode().split()
def run_firewall_dir(self): def run_firewall_dir(self):
@ -511,20 +513,24 @@ class NftablesWorker(FirewallWorker):
def update_connected_ips(self, family): def update_connected_ips(self, family):
family_name = ('ip6' if family == 6 else 'ip') family_name = ('ip6' if family == 6 else 'ip')
ips = self.get_connected_ips(family) table = 'qubes-firewall'
if ips:
addr = '{' + ', '.join(ips) + '}'
irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format(
family_name=family_name, addr=addr)
orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format(
family_name=family_name, addr=addr)
else:
irule = ''
orule = ''
nft_input = ( self.run_nft((
'flush chain {family_name} {table} prerouting\n' 'flush chain {family_name} {table} prerouting\n'
'flush chain {family_name} {table} postrouting\n' 'flush chain {family_name} {table} postrouting\n'
).format(family_name=family_name, table=table))
ips = self.get_connected_ips(family)
if not ips:
return
addr = '{' + ', '.join(ips) + '}'
irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format(
family_name=family_name, addr=addr)
orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format(
family_name=family_name, addr=addr)
nft_input = (
'table {family_name} {table} {{\n' 'table {family_name} {table} {{\n'
' chain prerouting {{\n' ' chain prerouting {{\n'
' {irule}' ' {irule}'
@ -535,7 +541,7 @@ class NftablesWorker(FirewallWorker):
'}}\n' '}}\n'
).format( ).format(
family_name=family_name, family_name=family_name,
table='qubes-firewall', table=table,
irule=irule, irule=irule,
orule=orule, orule=orule,
) )

View File

@ -349,6 +349,25 @@ class TestIptablesWorker(TestCase):
'!', '-o', 'vif+', '-d', '10.137.0.2', '-j', 'DROP'] '!', '-o', 'vif+', '-d', '10.137.0.2', '-j', 'DROP']
]) ])
def test_009_update_connected_ips_empty(self):
self.obj.qdb.entries['/connected-ips'] = b''
self.obj.called_commands[4] = []
self.obj.update_connected_ips(4)
self.assertEqual(self.obj.called_commands[4], [
['-t', 'raw', '-F', 'QBS-PREROUTING'],
['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
])
def test_010_update_connected_ips_missing(self):
self.obj.called_commands[4] = []
self.obj.update_connected_ips(4)
self.assertEqual(self.obj.called_commands[4], [
['-t', 'raw', '-F', 'QBS-PREROUTING'],
['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
])
class TestNftablesWorker(TestCase): class TestNftablesWorker(TestCase):
def setUp(self): def setUp(self):
@ -534,7 +553,8 @@ class TestNftablesWorker(TestCase):
self.assertEqual(self.obj.loaded_rules, [ self.assertEqual(self.obj.loaded_rules, [
'flush chain ip qubes-firewall prerouting\n' 'flush chain ip qubes-firewall prerouting\n'
'flush chain ip qubes-firewall postrouting\n' 'flush chain ip qubes-firewall postrouting\n',
'table ip qubes-firewall {\n' 'table ip qubes-firewall {\n'
' chain prerouting {\n' ' chain prerouting {\n'
' iifname != "vif*" ip saddr {10.137.0.1, 10.137.0.2} drop\n' ' iifname != "vif*" ip saddr {10.137.0.1, 10.137.0.2} drop\n'
@ -545,6 +565,25 @@ class TestNftablesWorker(TestCase):
'}\n' '}\n'
]) ])
def test_009_update_connected_ips_empty(self):
self.obj.qdb.entries['/connected-ips'] = b''
self.obj.loaded_rules = []
self.obj.update_connected_ips(4)
self.assertEqual(self.obj.loaded_rules, [
'flush chain ip qubes-firewall prerouting\n'
'flush chain ip qubes-firewall postrouting\n'
])
def test_010_update_connected_ips_missing(self):
self.obj.loaded_rules = []
self.obj.update_connected_ips(4)
self.assertEqual(self.obj.loaded_rules, [
'flush chain ip qubes-firewall prerouting\n'
'flush chain ip qubes-firewall postrouting\n'
])
class TestFirewallWorker(TestCase): class TestFirewallWorker(TestCase):
def setUp(self): def setUp(self):
self.obj = FirewallWorker() self.obj = FirewallWorker()