Merge remote-tracking branch 'origin/pr/303'

* origin/pr/303:
  firewall: prefer - over _ for QubesDB path
  firewall: put DNS resolving into its own function
  firewall: start watches before initial load
  tests/firewall: added test for /dns/[ip]/[domain] info
  tests/firewall: some code refactoring
  add some checks for QubesDB /qubes-firewall_handled/[ip]
  firewall: adjust tests to the new tuple returned by prepare_rules()
  firewall: mark an IP as handled in /qubes-firewall_handled/[ip] after each handling iteration
  mock qubesdb.rm()
  firewall: refactor to remove side effects from prepare_rules()
  Export DNS information obtained during firewall setup to QubesDB
This commit is contained in:
Marek Marczykowski-Górecki 2021-06-01 05:15:27 +02:00
commit 39a010445e
No known key found for this signature in database
GPG Key ID: 063938BA42CFA724
2 changed files with 137 additions and 43 deletions

View File

@ -127,6 +127,55 @@ class FirewallWorker(object):
rules.append({'action': policy}) rules.append({'action': policy})
return rules return rules
def resolve_dns(self, fqdn, family):
"""
Resolve the given FQDN via DNS.
:param fqdn: FQDN
:param family: 4 or 6 for IPv4 or IPv6
:return: see socket.getaddrinfo()
:raises: RuleParseError
"""
try:
addrinfo = socket.getaddrinfo(fqdn, None,
(socket.AF_INET6 if family == 6 else socket.AF_INET))
except socket.gaierror as e:
raise RuleParseError('Failed to resolve {}: {}'.format(
fqdn, str(e)))
except UnicodeError as e:
raise RuleParseError('Invalid destination {}: {}'.format(
fqdn, str(e)))
return addrinfo
def update_dns_info(self, source, dns):
"""
Write resolved DNS addresses back to QubesDB. This can be useful
for the user or DNS applications to pin these DNS addresses to the
IPs resolved during firewall setup.
:param source: VM IP
:param dns: dict: hostname -> set of IP addresses
:return: None
"""
#clear old info
self.qdb.rm('/dns/{}/'.format(source))
for host, hostaddrs in dns.items():
self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
def update_handled(self, addr):
"""
Update the QubesDB count of how often the given address was handled.
User applications may watch these paths for count increases to remain
up to date with QubesDB changes.
"""
cnt = self.qdb.read('/qubes-firewall-handled/{}'.format(addr))
try:
cnt = int(cnt)
except (TypeError, ValueError):
cnt = 0
self.qdb.write('/qubes-firewall-handled/{}'.format(addr), str(cnt+1))
def list_targets(self): def list_targets(self):
return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/')) return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
@ -163,6 +212,8 @@ class FirewallWorker(object):
self.log_error( self.log_error(
'Failed to block traffic for {}'.format(addr)) 'Failed to block traffic for {}'.format(addr))
self.update_handled(addr)
@staticmethod @staticmethod
def dns_addresses(family=None): def dns_addresses(family=None):
with open('/etc/resolv.conf') as resolv: with open('/etc/resolv.conf') as resolv:
@ -180,14 +231,14 @@ class FirewallWorker(object):
self.run_firewall_dir() self.run_firewall_dir()
self.run_user_script() self.run_user_script()
self.sd_notify('READY=1') self.sd_notify('READY=1')
self.qdb.watch('/qubes-firewall/')
self.qdb.watch('/connected-ips')
self.qdb.watch('/connected-ips6')
# initial load # initial load
for source_addr in self.list_targets(): for source_addr in self.list_targets():
self.handle_addr(source_addr) self.handle_addr(source_addr)
self.update_connected_ips(4) self.update_connected_ips(4)
self.update_connected_ips(6) self.update_connected_ips(6)
self.qdb.watch('/qubes-firewall/')
self.qdb.watch('/connected-ips')
self.qdb.watch('/connected-ips6')
try: try:
for watch_path in iter(self.qdb.read_watch, None): for watch_path in iter(self.qdb.read_watch, None):
if watch_path == '/connected-ips': if watch_path == '/connected-ips':
@ -271,8 +322,9 @@ class IptablesWorker(FirewallWorker):
:param chain: name of the chain to put rules into :param chain: name of the chain to put rules into
:param rules: list of rules :param rules: list of rules
:param family: address family (4 or 6) :param family: address family (4 or 6)
:return: input for iptables-restore :return: tuple: (input for iptables-restore, dict of DNS records resolved
:rtype: str during execution)
:rtype: (str, dict)
""" """
iptables = "*filter\n" iptables = "*filter\n"
@ -281,6 +333,8 @@ class IptablesWorker(FirewallWorker):
dns = list(addr + fullmask for addr in self.dns_addresses(family)) dns = list(addr + fullmask for addr in self.dns_addresses(family))
ret_dns = {}
for rule in rules: for rule in rules:
unsupported_opts = set(rule.keys()).difference( unsupported_opts = set(rule.keys()).difference(
set(self.supported_rule_opts)) set(self.supported_rule_opts))
@ -305,13 +359,9 @@ class IptablesWorker(FirewallWorker):
elif 'dst6' in rule: elif 'dst6' in rule:
dsthosts = [rule['dst6']] dsthosts = [rule['dst6']]
elif 'dsthost' in rule: elif 'dsthost' in rule:
try: addrinfo = self.resolve_dns(rule['dsthost'], family)
addrinfo = socket.getaddrinfo(rule['dsthost'], None,
(socket.AF_INET6 if family == 6 else socket.AF_INET))
except socket.gaierror as e:
raise RuleParseError('Failed to resolve {}: {}'.format(
rule['dsthost'], str(e)))
dsthosts = set(item[4][0] + fullmask for item in addrinfo) dsthosts = set(item[4][0] + fullmask for item in addrinfo)
ret_dns[rule['dsthost']] = dsthosts
else: else:
dsthosts = None dsthosts = None
@ -374,7 +424,7 @@ class IptablesWorker(FirewallWorker):
iptables += ipt_rule iptables += ipt_rule
iptables += 'COMMIT\n' iptables += 'COMMIT\n'
return iptables return (iptables, ret_dns)
def apply_rules_family(self, source, rules, family): def apply_rules_family(self, source, rules, family):
""" """
@ -391,7 +441,7 @@ class IptablesWorker(FirewallWorker):
if chain not in self.chains[family]: if chain not in self.chains[family]:
self.create_chain(source, chain, family) self.create_chain(source, chain, family)
iptables = self.prepare_rules(chain, rules, family) (iptables, dns) = self.prepare_rules(chain, rules, family)
try: try:
self.run_ipt(family, ['-F', chain]) self.run_ipt(family, ['-F', chain])
p = self.run_ipt_restore(family, ['-n']) p = self.run_ipt_restore(family, ['-n'])
@ -399,6 +449,7 @@ class IptablesWorker(FirewallWorker):
if p.returncode != 0: if p.returncode != 0:
raise RuleApplyError( raise RuleApplyError(
'iptables-restore failed: {}'.format(output)) 'iptables-restore failed: {}'.format(output))
self.update_dns_info(source, dns)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
raise RuleApplyError('\'iptables -F {}\' failed: {}'.format( raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
chain, e.output)) chain, e.output))
@ -561,13 +612,14 @@ class NftablesWorker(FirewallWorker):
def prepare_rules(self, chain, rules, family): def prepare_rules(self, chain, rules, family):
""" """
Helper function to translate rules list into input for iptables-restore Helper function to translate rules list into input for nft
:param chain: name of the chain to put rules into :param chain: name of the chain to put rules into
:param rules: list of rules :param rules: list of rules
:param family: address family (4 or 6) :param family: address family (4 or 6)
:return: input for iptables-restore :return: tuple: (input for nft, dict of DNS records resolved
:rtype: str during execution)
:rtype: (str, dict)
""" """
assert family in (4, 6) assert family in (4, 6)
@ -578,6 +630,8 @@ class NftablesWorker(FirewallWorker):
dns = list(addr + fullmask for addr in self.dns_addresses(family)) dns = list(addr + fullmask for addr in self.dns_addresses(family))
ret_dns = {}
for rule in rules: for rule in rules:
unsupported_opts = set(rule.keys()).difference( unsupported_opts = set(rule.keys()).difference(
set(self.supported_rule_opts)) set(self.supported_rule_opts))
@ -613,17 +667,11 @@ class NftablesWorker(FirewallWorker):
elif 'dst6' in rule: elif 'dst6' in rule:
nft_rule += ' ip6 daddr {}'.format(rule['dst6']) nft_rule += ' ip6 daddr {}'.format(rule['dst6'])
elif 'dsthost' in rule: elif 'dsthost' in rule:
try: addrinfo = self.resolve_dns(rule['dsthost'], family)
addrinfo = socket.getaddrinfo(rule['dsthost'], None, dsthosts = set(item[4][0] + fullmask for item in addrinfo)
(socket.AF_INET6 if family == 6 else socket.AF_INET))
except socket.gaierror as e:
raise RuleParseError('Failed to resolve {}: {}'.format(
rule['dsthost'], str(e)))
except UnicodeError as e:
raise RuleParseError('Invalid destination {}: {}'.format(
rule['dsthost'], str(e)))
nft_rule += ' {} daddr {{ {} }}'.format(ip_match, nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
', '.join(set(item[4][0] + fullmask for item in addrinfo))) ', '.join(dsthosts))
ret_dns[rule['dsthost']] = dsthosts
if 'dstports' in rule: if 'dstports' in rule:
dstports = rule['dstports'] dstports = rule['dstports']
@ -677,7 +725,7 @@ class NftablesWorker(FirewallWorker):
table='qubes-firewall', table='qubes-firewall',
chain=chain, chain=chain,
rules='\n '.join(nft_rules) rules='\n '.join(nft_rules)
)) ), ret_dns)
def apply_rules_family(self, source, rules, family): def apply_rules_family(self, source, rules, family):
""" """
@ -694,7 +742,9 @@ class NftablesWorker(FirewallWorker):
if chain not in self.chains[family]: if chain not in self.chains[family]:
self.create_chain(source, chain, family) self.create_chain(source, chain, family)
self.run_nft(self.prepare_rules(chain, rules, family)) (nft, dns) = self.prepare_rules(chain, rules, family)
self.run_nft(nft)
self.update_dns_info(source, dns)
def apply_rules(self, source, rules): def apply_rules(self, source, rules):
if self.is_ip6(source): if self.is_ip6(source):

View File

@ -1,5 +1,6 @@
import logging import logging
import operator import operator
import re
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch from unittest.mock import patch
@ -29,6 +30,17 @@ class DummyQubesDB(object):
except KeyError: except KeyError:
return None return None
def rm(self, path):
if path.endswith('/'):
for key in list(self.entries):
if key.startswith(path):
self.entries.pop(key)
else:
self.entries.pop(path)
def write(self, path, val):
self.entries[path] = val
def multiread(self, prefix): def multiread(self, prefix):
result = {} result = {}
for key, value in self.entries.items(): for key, value in self.entries.items():
@ -154,8 +166,31 @@ class NftablesWorker(qubesagent.firewall.NftablesWorker):
else: else:
return ['2001::1', '2001::2'] return ['2001::1', '2001::2']
class WorkerCommon(object):
def assertPrepareRulesDnsRet(self, dns_ret, expected_domain, family):
self.assertEqual(dns_ret.keys(), {expected_domain})
self.assertIsInstance(dns_ret[expected_domain], set)
if family == 4:
self.assertIsNotNone(re.match('^\d+\.\d+\.\d+\.\d+/32$',
dns_ret[expected_domain].pop()))
elif family == 6:
self.assertIsNotNone(re.match('^[0-9a-f:]+/\d+$',
dns_ret[expected_domain].pop()))
else:
raise ValueError()
class TestIptablesWorker(TestCase): def test_701_dns_info(self):
rules = [
{'action': 'accept', 'proto': 'tcp',
'dstports': '80-80', 'dsthost': 'ripe.net'},
{'action': 'drop'},
]
self.obj.apply_rules('10.137.0.1', rules)
self.assertIsNotNone(self.obj.qdb.read('/dns/10.137.0.1/ripe.net'))
self.obj.apply_rules('10.137.0.1', [{'action': 'drop'}])
self.assertIsNone(self.obj.qdb.read('/dns/10.137.0.1/ripe.net'))
class TestIptablesWorker(TestCase, WorkerCommon):
def setUp(self): def setUp(self):
super(TestIptablesWorker, self).setUp() super(TestIptablesWorker, self).setUp()
self.obj = IptablesWorker() self.obj = IptablesWorker()
@ -212,8 +247,9 @@ class TestIptablesWorker(TestCase):
"--reject-with icmp-admin-prohibited\n" "--reject-with icmp-admin-prohibited\n"
"COMMIT\n" "COMMIT\n"
) )
self.assertEqual(self.obj.prepare_rules('chain', rules, 4), ret = self.obj.prepare_rules('chain', rules, 4)
expected_iptables) self.assertEqual(ret[0], expected_iptables)
self.assertPrepareRulesDnsRet(ret[1], 'yum.qubes-os.org', 4)
with self.assertRaises(qubesagent.firewall.RuleParseError): with self.assertRaises(qubesagent.firewall.RuleParseError):
self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4) self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4)
with self.assertRaises(qubesagent.firewall.RuleParseError): with self.assertRaises(qubesagent.firewall.RuleParseError):
@ -250,8 +286,9 @@ class TestIptablesWorker(TestCase):
"--reject-with icmp6-adm-prohibited\n" "--reject-with icmp6-adm-prohibited\n"
"COMMIT\n" "COMMIT\n"
) )
self.assertEqual(self.obj.prepare_rules('chain', rules, 6), ret = self.obj.prepare_rules('chain', rules, 6)
expected_iptables) self.assertEqual(ret[0], expected_iptables)
self.assertPrepareRulesDnsRet(ret[1], 'ripe.net', 6)
def test_004_apply_rules4(self): def test_004_apply_rules4(self):
rules = [{'action': 'accept'}] rules = [{'action': 'accept'}]
@ -263,7 +300,7 @@ class TestIptablesWorker(TestCase):
['-I', 'QBS-FORWARD', '-s', '10.137.0.1', '-j', chain], ['-I', 'QBS-FORWARD', '-s', '10.137.0.1', '-j', chain],
['-F', chain]]) ['-F', chain]])
self.assertEqual(self.obj.loaded_iptables[4], self.assertEqual(self.obj.loaded_iptables[4],
self.obj.prepare_rules(chain, rules, 4)) self.obj.prepare_rules(chain, rules, 4)[0])
self.assertEqual(self.obj.called_commands[6], []) self.assertEqual(self.obj.called_commands[6], [])
self.assertIsNone(self.obj.loaded_iptables[6]) self.assertIsNone(self.obj.loaded_iptables[6])
@ -277,7 +314,7 @@ class TestIptablesWorker(TestCase):
['-I', 'QBS-FORWARD', '-s', '2000::a', '-j', chain], ['-I', 'QBS-FORWARD', '-s', '2000::a', '-j', chain],
['-F', chain]]) ['-F', chain]])
self.assertEqual(self.obj.loaded_iptables[6], self.assertEqual(self.obj.loaded_iptables[6],
self.obj.prepare_rules(chain, rules, 6)) self.obj.prepare_rules(chain, rules, 6)[0])
self.assertEqual(self.obj.called_commands[4], []) self.assertEqual(self.obj.called_commands[4], [])
self.assertIsNone(self.obj.loaded_iptables[4]) self.assertIsNone(self.obj.loaded_iptables[4])
@ -372,8 +409,7 @@ class TestIptablesWorker(TestCase):
['-t', 'mangle', '-F', 'QBS-POSTROUTING'], ['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
]) ])
class TestNftablesWorker(TestCase, WorkerCommon):
class TestNftablesWorker(TestCase):
def setUp(self): def setUp(self):
super(TestNftablesWorker, self).setUp() super(TestNftablesWorker, self).setUp()
self.obj = NftablesWorker() self.obj = NftablesWorker()
@ -440,8 +476,9 @@ class TestNftablesWorker(TestCase):
' }\n' ' }\n'
'}\n' '}\n'
) )
self.assertEqual(self.obj.prepare_rules('chain', rules, 4), ret = self.obj.prepare_rules('chain', rules, 4)
expected_nft) self.assertEqual(ret[0], expected_nft)
self.assertPrepareRulesDnsRet(ret[1], 'yum.qubes-os.org', 4)
with self.assertRaises(qubesagent.firewall.RuleParseError): with self.assertRaises(qubesagent.firewall.RuleParseError):
self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4) self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4)
with self.assertRaises(qubesagent.firewall.RuleParseError): with self.assertRaises(qubesagent.firewall.RuleParseError):
@ -477,8 +514,9 @@ class TestNftablesWorker(TestCase):
' }\n' ' }\n'
'}\n' '}\n'
) )
self.assertEqual(self.obj.prepare_rules('chain', rules, 6), ret = self.obj.prepare_rules('chain', rules, 6)
expected_nft) self.assertEqual(ret[0], expected_nft)
self.assertPrepareRulesDnsRet(ret[1], 'ripe.net', 6)
def test_004_apply_rules4(self): def test_004_apply_rules4(self):
rules = [{'action': 'accept'}] rules = [{'action': 'accept'}]
@ -486,7 +524,7 @@ class TestNftablesWorker(TestCase):
self.obj.apply_rules('10.137.0.1', rules) self.obj.apply_rules('10.137.0.1', rules)
self.assertEqual(self.obj.loaded_rules, self.assertEqual(self.obj.loaded_rules,
[self.expected_create_chain('ip', '10.137.0.1', chain), [self.expected_create_chain('ip', '10.137.0.1', chain),
self.obj.prepare_rules(chain, rules, 4), self.obj.prepare_rules(chain, rules, 4)[0],
]) ])
def test_005_apply_rules6(self): def test_005_apply_rules6(self):
@ -495,7 +533,7 @@ class TestNftablesWorker(TestCase):
self.obj.apply_rules('2000::a', rules) self.obj.apply_rules('2000::a', rules)
self.assertEqual(self.obj.loaded_rules, self.assertEqual(self.obj.loaded_rules,
[self.expected_create_chain('ip6', '2000::a', chain), [self.expected_create_chain('ip6', '2000::a', chain),
self.obj.prepare_rules(chain, rules, 6), self.obj.prepare_rules(chain, rules, 6)[0],
]) ])
def test_006_init(self): def test_006_init(self):
@ -647,11 +685,17 @@ class TestFirewallWorker(TestCase):
def test_handle_addr(self): def test_handle_addr(self):
self.obj.handle_addr('10.137.0.2') self.obj.handle_addr('10.137.0.2')
self.assertEqual(self.obj.rules['10.137.0.2'], [{'action': 'accept'}]) self.assertEqual(self.obj.rules['10.137.0.2'], [{'action': 'accept'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.2'], '1')
self.obj.handle_addr('10.137.0.2')
self.assertEqual(self.obj.rules['10.137.0.2'], [{'action': 'accept'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.2'], '2')
# fallback to block all # fallback to block all
self.obj.handle_addr('10.137.0.3') self.obj.handle_addr('10.137.0.3')
self.assertEqual(self.obj.rules['10.137.0.3'], [{'action': 'drop'}]) self.assertEqual(self.obj.rules['10.137.0.3'], [{'action': 'drop'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.3'], '1')
self.obj.handle_addr('10.137.0.4') self.obj.handle_addr('10.137.0.4')
self.assertEqual(self.obj.rules['10.137.0.4'], [{'action': 'drop'}]) self.assertEqual(self.obj.rules['10.137.0.4'], [{'action': 'drop'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.4'], '1')
@patch('os.path.isfile') @patch('os.path.isfile')
@patch('os.access') @patch('os.access')