get_connected_ips: handle empty and missing keys, add tests
This commit is contained in:
parent
e43fd2fc5a
commit
4aace50313
@ -82,6 +82,8 @@ class FirewallWorker(object):
|
||||
|
||||
def get_connected_ips(self, family):
|
||||
ips = self.qdb.read('/connected-ips6' if family == 6 else '/connected-ips')
|
||||
if ips is None:
|
||||
return []
|
||||
return ips.decode().split()
|
||||
|
||||
def run_firewall_dir(self):
|
||||
@ -511,20 +513,24 @@ class NftablesWorker(FirewallWorker):
|
||||
|
||||
def update_connected_ips(self, family):
|
||||
family_name = ('ip6' if family == 6 else 'ip')
|
||||
ips = self.get_connected_ips(family)
|
||||
if ips:
|
||||
addr = '{' + ', '.join(ips) + '}'
|
||||
irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format(
|
||||
family_name=family_name, addr=addr)
|
||||
orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format(
|
||||
family_name=family_name, addr=addr)
|
||||
else:
|
||||
irule = ''
|
||||
orule = ''
|
||||
table = 'qubes-firewall'
|
||||
|
||||
nft_input = (
|
||||
self.run_nft((
|
||||
'flush chain {family_name} {table} prerouting\n'
|
||||
'flush chain {family_name} {table} postrouting\n'
|
||||
).format(family_name=family_name, table=table))
|
||||
|
||||
ips = self.get_connected_ips(family)
|
||||
if not ips:
|
||||
return
|
||||
|
||||
addr = '{' + ', '.join(ips) + '}'
|
||||
irule = 'iifname != "vif*" {family_name} saddr {addr} drop\n'.format(
|
||||
family_name=family_name, addr=addr)
|
||||
orule = 'oifname != "vif*" {family_name} daddr {addr} drop\n'.format(
|
||||
family_name=family_name, addr=addr)
|
||||
|
||||
nft_input = (
|
||||
'table {family_name} {table} {{\n'
|
||||
' chain prerouting {{\n'
|
||||
' {irule}'
|
||||
@ -535,7 +541,7 @@ class NftablesWorker(FirewallWorker):
|
||||
'}}\n'
|
||||
).format(
|
||||
family_name=family_name,
|
||||
table='qubes-firewall',
|
||||
table=table,
|
||||
irule=irule,
|
||||
orule=orule,
|
||||
)
|
||||
|
@ -349,6 +349,25 @@ class TestIptablesWorker(TestCase):
|
||||
'!', '-o', 'vif+', '-d', '10.137.0.2', '-j', 'DROP']
|
||||
])
|
||||
|
||||
def test_009_update_connected_ips_empty(self):
|
||||
self.obj.qdb.entries['/connected-ips'] = b''
|
||||
self.obj.called_commands[4] = []
|
||||
self.obj.update_connected_ips(4)
|
||||
|
||||
self.assertEqual(self.obj.called_commands[4], [
|
||||
['-t', 'raw', '-F', 'QBS-PREROUTING'],
|
||||
['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
|
||||
])
|
||||
|
||||
def test_010_update_connected_ips_missing(self):
|
||||
self.obj.called_commands[4] = []
|
||||
self.obj.update_connected_ips(4)
|
||||
|
||||
self.assertEqual(self.obj.called_commands[4], [
|
||||
['-t', 'raw', '-F', 'QBS-PREROUTING'],
|
||||
['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
|
||||
])
|
||||
|
||||
|
||||
class TestNftablesWorker(TestCase):
|
||||
def setUp(self):
|
||||
@ -534,7 +553,8 @@ class TestNftablesWorker(TestCase):
|
||||
|
||||
self.assertEqual(self.obj.loaded_rules, [
|
||||
'flush chain ip qubes-firewall prerouting\n'
|
||||
'flush chain ip qubes-firewall postrouting\n'
|
||||
'flush chain ip qubes-firewall postrouting\n',
|
||||
|
||||
'table ip qubes-firewall {\n'
|
||||
' chain prerouting {\n'
|
||||
' iifname != "vif*" ip saddr {10.137.0.1, 10.137.0.2} drop\n'
|
||||
@ -545,6 +565,25 @@ class TestNftablesWorker(TestCase):
|
||||
'}\n'
|
||||
])
|
||||
|
||||
def test_009_update_connected_ips_empty(self):
|
||||
self.obj.qdb.entries['/connected-ips'] = b''
|
||||
self.obj.loaded_rules = []
|
||||
self.obj.update_connected_ips(4)
|
||||
|
||||
self.assertEqual(self.obj.loaded_rules, [
|
||||
'flush chain ip qubes-firewall prerouting\n'
|
||||
'flush chain ip qubes-firewall postrouting\n'
|
||||
])
|
||||
|
||||
def test_010_update_connected_ips_missing(self):
|
||||
self.obj.loaded_rules = []
|
||||
self.obj.update_connected_ips(4)
|
||||
|
||||
self.assertEqual(self.obj.loaded_rules, [
|
||||
'flush chain ip qubes-firewall prerouting\n'
|
||||
'flush chain ip qubes-firewall postrouting\n'
|
||||
])
|
||||
|
||||
class TestFirewallWorker(TestCase):
|
||||
def setUp(self):
|
||||
self.obj = FirewallWorker()
|
||||
|
Loading…
Reference in New Issue
Block a user