Merge remote-tracking branch 'origin/pr/303'

* origin/pr/303:
  firewall: prefer - over _ for QubesDB path
  firewall: put DNS resolving into its own function
  firewall: start watches before initial load
  tests/firewall: added test for /dns/[ip]/[domain] info
  tests/firewall: some code refactoring
  add some checks for QubesDB /qubes-firewall_handled/[ip]
  firewall: adjust tests to the new tuple returned by prepare_rules()
  firewall: mark an IP as handled in /qubes-firewall_handled/[ip] after each handling iteration
  mock qubesdb.rm()
  firewall: refactor to remove side effects from prepare_rules()
  Export DNS information obtained during firewall setup to QubesDB
This commit is contained in:
Marek Marczykowski-Górecki 2021-06-01 05:15:27 +02:00
commit 39a010445e
No known key found for this signature in database
GPG Key ID: 063938BA42CFA724
2 changed files with 137 additions and 43 deletions

View File

@ -127,6 +127,55 @@ class FirewallWorker(object):
rules.append({'action': policy})
return rules
def resolve_dns(self, fqdn, family):
"""
Resolve the given FQDN via DNS.
:param fqdn: FQDN
:param family: 4 or 6 for IPv4 or IPv6
:return: see socket.getaddrinfo()
:raises: RuleParseError
"""
try:
addrinfo = socket.getaddrinfo(fqdn, None,
(socket.AF_INET6 if family == 6 else socket.AF_INET))
except socket.gaierror as e:
raise RuleParseError('Failed to resolve {}: {}'.format(
fqdn, str(e)))
except UnicodeError as e:
raise RuleParseError('Invalid destination {}: {}'.format(
fqdn, str(e)))
return addrinfo
def update_dns_info(self, source, dns):
"""
Write resolved DNS addresses back to QubesDB. This can be useful
for the user or DNS applications to pin these DNS addresses to the
IPs resolved during firewall setup.
:param source: VM IP
:param dns: dict: hostname -> set of IP addresses
:return: None
"""
#clear old info
self.qdb.rm('/dns/{}/'.format(source))
for host, hostaddrs in dns.items():
self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
def update_handled(self, addr):
"""
Update the QubesDB count of how often the given address was handled.
User applications may watch these paths for count increases to remain
up to date with QubesDB changes.
"""
cnt = self.qdb.read('/qubes-firewall-handled/{}'.format(addr))
try:
cnt = int(cnt)
except (TypeError, ValueError):
cnt = 0
self.qdb.write('/qubes-firewall-handled/{}'.format(addr), str(cnt+1))
def list_targets(self):
return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
@ -163,6 +212,8 @@ class FirewallWorker(object):
self.log_error(
'Failed to block traffic for {}'.format(addr))
self.update_handled(addr)
@staticmethod
def dns_addresses(family=None):
with open('/etc/resolv.conf') as resolv:
@ -180,14 +231,14 @@ class FirewallWorker(object):
self.run_firewall_dir()
self.run_user_script()
self.sd_notify('READY=1')
self.qdb.watch('/qubes-firewall/')
self.qdb.watch('/connected-ips')
self.qdb.watch('/connected-ips6')
# initial load
for source_addr in self.list_targets():
self.handle_addr(source_addr)
self.update_connected_ips(4)
self.update_connected_ips(6)
self.qdb.watch('/qubes-firewall/')
self.qdb.watch('/connected-ips')
self.qdb.watch('/connected-ips6')
try:
for watch_path in iter(self.qdb.read_watch, None):
if watch_path == '/connected-ips':
@ -271,8 +322,9 @@ class IptablesWorker(FirewallWorker):
:param chain: name of the chain to put rules into
:param rules: list of rules
:param family: address family (4 or 6)
:return: input for iptables-restore
:rtype: str
:return: tuple: (input for iptables-restore, dict of DNS records resolved
during execution)
:rtype: (str, dict)
"""
iptables = "*filter\n"
@ -281,6 +333,8 @@ class IptablesWorker(FirewallWorker):
dns = list(addr + fullmask for addr in self.dns_addresses(family))
ret_dns = {}
for rule in rules:
unsupported_opts = set(rule.keys()).difference(
set(self.supported_rule_opts))
@ -305,13 +359,9 @@ class IptablesWorker(FirewallWorker):
elif 'dst6' in rule:
dsthosts = [rule['dst6']]
elif 'dsthost' in rule:
try:
addrinfo = socket.getaddrinfo(rule['dsthost'], None,
(socket.AF_INET6 if family == 6 else socket.AF_INET))
except socket.gaierror as e:
raise RuleParseError('Failed to resolve {}: {}'.format(
rule['dsthost'], str(e)))
addrinfo = self.resolve_dns(rule['dsthost'], family)
dsthosts = set(item[4][0] + fullmask for item in addrinfo)
ret_dns[rule['dsthost']] = dsthosts
else:
dsthosts = None
@ -374,7 +424,7 @@ class IptablesWorker(FirewallWorker):
iptables += ipt_rule
iptables += 'COMMIT\n'
return iptables
return (iptables, ret_dns)
def apply_rules_family(self, source, rules, family):
"""
@ -391,7 +441,7 @@ class IptablesWorker(FirewallWorker):
if chain not in self.chains[family]:
self.create_chain(source, chain, family)
iptables = self.prepare_rules(chain, rules, family)
(iptables, dns) = self.prepare_rules(chain, rules, family)
try:
self.run_ipt(family, ['-F', chain])
p = self.run_ipt_restore(family, ['-n'])
@ -399,6 +449,7 @@ class IptablesWorker(FirewallWorker):
if p.returncode != 0:
raise RuleApplyError(
'iptables-restore failed: {}'.format(output))
self.update_dns_info(source, dns)
except subprocess.CalledProcessError as e:
raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
chain, e.output))
@ -561,13 +612,14 @@ class NftablesWorker(FirewallWorker):
def prepare_rules(self, chain, rules, family):
"""
Helper function to translate rules list into input for iptables-restore
Helper function to translate rules list into input for nft
:param chain: name of the chain to put rules into
:param rules: list of rules
:param family: address family (4 or 6)
:return: input for iptables-restore
:rtype: str
:return: tuple: (input for nft, dict of DNS records resolved
during execution)
:rtype: (str, dict)
"""
assert family in (4, 6)
@ -578,6 +630,8 @@ class NftablesWorker(FirewallWorker):
dns = list(addr + fullmask for addr in self.dns_addresses(family))
ret_dns = {}
for rule in rules:
unsupported_opts = set(rule.keys()).difference(
set(self.supported_rule_opts))
@ -613,17 +667,11 @@ class NftablesWorker(FirewallWorker):
elif 'dst6' in rule:
nft_rule += ' ip6 daddr {}'.format(rule['dst6'])
elif 'dsthost' in rule:
try:
addrinfo = socket.getaddrinfo(rule['dsthost'], None,
(socket.AF_INET6 if family == 6 else socket.AF_INET))
except socket.gaierror as e:
raise RuleParseError('Failed to resolve {}: {}'.format(
rule['dsthost'], str(e)))
except UnicodeError as e:
raise RuleParseError('Invalid destination {}: {}'.format(
rule['dsthost'], str(e)))
addrinfo = self.resolve_dns(rule['dsthost'], family)
dsthosts = set(item[4][0] + fullmask for item in addrinfo)
nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
', '.join(set(item[4][0] + fullmask for item in addrinfo)))
', '.join(dsthosts))
ret_dns[rule['dsthost']] = dsthosts
if 'dstports' in rule:
dstports = rule['dstports']
@ -677,7 +725,7 @@ class NftablesWorker(FirewallWorker):
table='qubes-firewall',
chain=chain,
rules='\n '.join(nft_rules)
))
), ret_dns)
def apply_rules_family(self, source, rules, family):
"""
@ -694,7 +742,9 @@ class NftablesWorker(FirewallWorker):
if chain not in self.chains[family]:
self.create_chain(source, chain, family)
self.run_nft(self.prepare_rules(chain, rules, family))
(nft, dns) = self.prepare_rules(chain, rules, family)
self.run_nft(nft)
self.update_dns_info(source, dns)
def apply_rules(self, source, rules):
if self.is_ip6(source):

View File

@ -1,5 +1,6 @@
import logging
import operator
import re
from unittest import TestCase
from unittest.mock import patch
@ -29,6 +30,17 @@ class DummyQubesDB(object):
except KeyError:
return None
def rm(self, path):
if path.endswith('/'):
for key in list(self.entries):
if key.startswith(path):
self.entries.pop(key)
else:
self.entries.pop(path)
def write(self, path, val):
self.entries[path] = val
def multiread(self, prefix):
result = {}
for key, value in self.entries.items():
@ -154,8 +166,31 @@ class NftablesWorker(qubesagent.firewall.NftablesWorker):
else:
return ['2001::1', '2001::2']
class WorkerCommon(object):
def assertPrepareRulesDnsRet(self, dns_ret, expected_domain, family):
self.assertEqual(dns_ret.keys(), {expected_domain})
self.assertIsInstance(dns_ret[expected_domain], set)
if family == 4:
self.assertIsNotNone(re.match('^\d+\.\d+\.\d+\.\d+/32$',
dns_ret[expected_domain].pop()))
elif family == 6:
self.assertIsNotNone(re.match('^[0-9a-f:]+/\d+$',
dns_ret[expected_domain].pop()))
else:
raise ValueError()
class TestIptablesWorker(TestCase):
def test_701_dns_info(self):
rules = [
{'action': 'accept', 'proto': 'tcp',
'dstports': '80-80', 'dsthost': 'ripe.net'},
{'action': 'drop'},
]
self.obj.apply_rules('10.137.0.1', rules)
self.assertIsNotNone(self.obj.qdb.read('/dns/10.137.0.1/ripe.net'))
self.obj.apply_rules('10.137.0.1', [{'action': 'drop'}])
self.assertIsNone(self.obj.qdb.read('/dns/10.137.0.1/ripe.net'))
class TestIptablesWorker(TestCase, WorkerCommon):
def setUp(self):
super(TestIptablesWorker, self).setUp()
self.obj = IptablesWorker()
@ -212,8 +247,9 @@ class TestIptablesWorker(TestCase):
"--reject-with icmp-admin-prohibited\n"
"COMMIT\n"
)
self.assertEqual(self.obj.prepare_rules('chain', rules, 4),
expected_iptables)
ret = self.obj.prepare_rules('chain', rules, 4)
self.assertEqual(ret[0], expected_iptables)
self.assertPrepareRulesDnsRet(ret[1], 'yum.qubes-os.org', 4)
with self.assertRaises(qubesagent.firewall.RuleParseError):
self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4)
with self.assertRaises(qubesagent.firewall.RuleParseError):
@ -250,8 +286,9 @@ class TestIptablesWorker(TestCase):
"--reject-with icmp6-adm-prohibited\n"
"COMMIT\n"
)
self.assertEqual(self.obj.prepare_rules('chain', rules, 6),
expected_iptables)
ret = self.obj.prepare_rules('chain', rules, 6)
self.assertEqual(ret[0], expected_iptables)
self.assertPrepareRulesDnsRet(ret[1], 'ripe.net', 6)
def test_004_apply_rules4(self):
rules = [{'action': 'accept'}]
@ -263,7 +300,7 @@ class TestIptablesWorker(TestCase):
['-I', 'QBS-FORWARD', '-s', '10.137.0.1', '-j', chain],
['-F', chain]])
self.assertEqual(self.obj.loaded_iptables[4],
self.obj.prepare_rules(chain, rules, 4))
self.obj.prepare_rules(chain, rules, 4)[0])
self.assertEqual(self.obj.called_commands[6], [])
self.assertIsNone(self.obj.loaded_iptables[6])
@ -277,7 +314,7 @@ class TestIptablesWorker(TestCase):
['-I', 'QBS-FORWARD', '-s', '2000::a', '-j', chain],
['-F', chain]])
self.assertEqual(self.obj.loaded_iptables[6],
self.obj.prepare_rules(chain, rules, 6))
self.obj.prepare_rules(chain, rules, 6)[0])
self.assertEqual(self.obj.called_commands[4], [])
self.assertIsNone(self.obj.loaded_iptables[4])
@ -372,8 +409,7 @@ class TestIptablesWorker(TestCase):
['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
])
class TestNftablesWorker(TestCase):
class TestNftablesWorker(TestCase, WorkerCommon):
def setUp(self):
super(TestNftablesWorker, self).setUp()
self.obj = NftablesWorker()
@ -440,8 +476,9 @@ class TestNftablesWorker(TestCase):
' }\n'
'}\n'
)
self.assertEqual(self.obj.prepare_rules('chain', rules, 4),
expected_nft)
ret = self.obj.prepare_rules('chain', rules, 4)
self.assertEqual(ret[0], expected_nft)
self.assertPrepareRulesDnsRet(ret[1], 'yum.qubes-os.org', 4)
with self.assertRaises(qubesagent.firewall.RuleParseError):
self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4)
with self.assertRaises(qubesagent.firewall.RuleParseError):
@ -477,8 +514,9 @@ class TestNftablesWorker(TestCase):
' }\n'
'}\n'
)
self.assertEqual(self.obj.prepare_rules('chain', rules, 6),
expected_nft)
ret = self.obj.prepare_rules('chain', rules, 6)
self.assertEqual(ret[0], expected_nft)
self.assertPrepareRulesDnsRet(ret[1], 'ripe.net', 6)
def test_004_apply_rules4(self):
rules = [{'action': 'accept'}]
@ -486,7 +524,7 @@ class TestNftablesWorker(TestCase):
self.obj.apply_rules('10.137.0.1', rules)
self.assertEqual(self.obj.loaded_rules,
[self.expected_create_chain('ip', '10.137.0.1', chain),
self.obj.prepare_rules(chain, rules, 4),
self.obj.prepare_rules(chain, rules, 4)[0],
])
def test_005_apply_rules6(self):
@ -495,7 +533,7 @@ class TestNftablesWorker(TestCase):
self.obj.apply_rules('2000::a', rules)
self.assertEqual(self.obj.loaded_rules,
[self.expected_create_chain('ip6', '2000::a', chain),
self.obj.prepare_rules(chain, rules, 6),
self.obj.prepare_rules(chain, rules, 6)[0],
])
def test_006_init(self):
@ -647,11 +685,17 @@ class TestFirewallWorker(TestCase):
def test_handle_addr(self):
self.obj.handle_addr('10.137.0.2')
self.assertEqual(self.obj.rules['10.137.0.2'], [{'action': 'accept'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.2'], '1')
self.obj.handle_addr('10.137.0.2')
self.assertEqual(self.obj.rules['10.137.0.2'], [{'action': 'accept'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.2'], '2')
# fallback to block all
self.obj.handle_addr('10.137.0.3')
self.assertEqual(self.obj.rules['10.137.0.3'], [{'action': 'drop'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.3'], '1')
self.obj.handle_addr('10.137.0.4')
self.assertEqual(self.obj.rules['10.137.0.4'], [{'action': 'drop'}])
self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.4'], '1')
@patch('os.path.isfile')
@patch('os.access')