firewall: refactor to remove side effects from prepare_rules()

This commit is contained in:
3hhh 2021-05-15 10:19:18 +02:00
parent 2a5af195f1
commit 196014831b
No known key found for this signature in database
GPG Key ID: EB03A691DB2F0833

View File

@ -127,23 +127,22 @@ class FirewallWorker(object):
rules.append({'action': policy}) rules.append({'action': policy})
return rules return rules
def write_dns_info(self, source, host, hostaddrs): def update_dns_info(self, source, dns):
""" """
Write resolved DNS addresses back to QubesDB. This can be useful Write resolved DNS addresses back to QubesDB. This can be useful
for the user or DNS applications to pin these DNS addresses to the for the user or DNS applications to pin these DNS addresses to the
IPs resolved during firewall setup. IPs resolved during firewall setup.
:param source: VM IP :param source: VM IP
:param host: hostname :param dns: dict: hostname -> set of IP addresses
:param hostaddrs: set of IP addresses :host: was resolved to
:return: None :return: None
""" """
self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs)) #clear old info
def clear_dns_info(self, source):
""" Clear all DNS info for the given VM IP."""
self.qdb.rm('/dns/{}/'.format(source)) self.qdb.rm('/dns/{}/'.format(source))
for host, hostaddrs in dns.items():
self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
def list_targets(self): def list_targets(self):
return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/')) return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
@ -281,16 +280,16 @@ class IptablesWorker(FirewallWorker):
['-I', 'QBS-FORWARD', '-s', addr, '-j', chain]) ['-I', 'QBS-FORWARD', '-s', addr, '-j', chain])
self.chains[family].add(chain) self.chains[family].add(chain)
def prepare_rules(self, chain, rules, family, source): def prepare_rules(self, chain, rules, family):
""" """
Helper function to translate rules list into input for iptables-restore Helper function to translate rules list into input for iptables-restore
:param chain: name of the chain to put rules into :param chain: name of the chain to put rules into
:param rules: list of rules :param rules: list of rules
:param family: address family (4 or 6) :param family: address family (4 or 6)
:param source: source for which to apply the chain :return: tuple: (input for iptables-restore, dict of DNS records resolved
:return: input for iptables-restore during execution)
:rtype: str :rtype: (str, dict)
""" """
iptables = "*filter\n" iptables = "*filter\n"
@ -299,7 +298,7 @@ class IptablesWorker(FirewallWorker):
dns = list(addr + fullmask for addr in self.dns_addresses(family)) dns = list(addr + fullmask for addr in self.dns_addresses(family))
self.clear_dns_info(source) ret_dns = {}
for rule in rules: for rule in rules:
unsupported_opts = set(rule.keys()).difference( unsupported_opts = set(rule.keys()).difference(
@ -332,7 +331,7 @@ class IptablesWorker(FirewallWorker):
raise RuleParseError('Failed to resolve {}: {}'.format( raise RuleParseError('Failed to resolve {}: {}'.format(
rule['dsthost'], str(e))) rule['dsthost'], str(e)))
dsthosts = set(item[4][0] + fullmask for item in addrinfo) dsthosts = set(item[4][0] + fullmask for item in addrinfo)
self.write_dns_info(source, rule['dsthost'], dsthosts) ret_dns[rule['dsthost']] = dsthosts
else: else:
dsthosts = None dsthosts = None
@ -395,7 +394,7 @@ class IptablesWorker(FirewallWorker):
iptables += ipt_rule iptables += ipt_rule
iptables += 'COMMIT\n' iptables += 'COMMIT\n'
return iptables return (iptables, ret_dns)
def apply_rules_family(self, source, rules, family): def apply_rules_family(self, source, rules, family):
""" """
@ -412,7 +411,7 @@ class IptablesWorker(FirewallWorker):
if chain not in self.chains[family]: if chain not in self.chains[family]:
self.create_chain(source, chain, family) self.create_chain(source, chain, family)
iptables = self.prepare_rules(chain, rules, family, source) (iptables, dns) = self.prepare_rules(chain, rules, family)
try: try:
self.run_ipt(family, ['-F', chain]) self.run_ipt(family, ['-F', chain])
p = self.run_ipt_restore(family, ['-n']) p = self.run_ipt_restore(family, ['-n'])
@ -420,6 +419,7 @@ class IptablesWorker(FirewallWorker):
if p.returncode != 0: if p.returncode != 0:
raise RuleApplyError( raise RuleApplyError(
'iptables-restore failed: {}'.format(output)) 'iptables-restore failed: {}'.format(output))
self.update_dns_info(source, dns)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
raise RuleApplyError('\'iptables -F {}\' failed: {}'.format( raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
chain, e.output)) chain, e.output))
@ -580,16 +580,16 @@ class NftablesWorker(FirewallWorker):
self.run_nft(nft_input) self.run_nft(nft_input)
def prepare_rules(self, chain, rules, family, source): def prepare_rules(self, chain, rules, family):
""" """
Helper function to translate rules list into input for iptables-restore Helper function to translate rules list into input for nft
:param chain: name of the chain to put rules into :param chain: name of the chain to put rules into
:param rules: list of rules :param rules: list of rules
:param family: address family (4 or 6) :param family: address family (4 or 6)
:param source: source for which to apply the chain :return: tuple: (input for nft, dict of DNS records resolved
:return: input for iptables-restore during execution)
:rtype: str :rtype: (str, dict)
""" """
assert family in (4, 6) assert family in (4, 6)
@ -600,7 +600,7 @@ class NftablesWorker(FirewallWorker):
dns = list(addr + fullmask for addr in self.dns_addresses(family)) dns = list(addr + fullmask for addr in self.dns_addresses(family))
self.clear_dns_info(source) ret_dns = {}
for rule in rules: for rule in rules:
unsupported_opts = set(rule.keys()).difference( unsupported_opts = set(rule.keys()).difference(
@ -649,7 +649,7 @@ class NftablesWorker(FirewallWorker):
dsthosts = set(item[4][0] + fullmask for item in addrinfo) dsthosts = set(item[4][0] + fullmask for item in addrinfo)
nft_rule += ' {} daddr {{ {} }}'.format(ip_match, nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
', '.join(dsthosts)) ', '.join(dsthosts))
self.write_dns_info(source, rule['dsthost'], dsthosts) ret_dns[rule['dsthost']] = dsthosts
if 'dstports' in rule: if 'dstports' in rule:
dstports = rule['dstports'] dstports = rule['dstports']
@ -703,7 +703,7 @@ class NftablesWorker(FirewallWorker):
table='qubes-firewall', table='qubes-firewall',
chain=chain, chain=chain,
rules='\n '.join(nft_rules) rules='\n '.join(nft_rules)
)) ), ret_dns)
def apply_rules_family(self, source, rules, family): def apply_rules_family(self, source, rules, family):
""" """
@ -720,7 +720,9 @@ class NftablesWorker(FirewallWorker):
if chain not in self.chains[family]: if chain not in self.chains[family]:
self.create_chain(source, chain, family) self.create_chain(source, chain, family)
self.run_nft(self.prepare_rules(chain, rules, family, source)) (nft, dns) = self.prepare_rules(chain, rules, family)
self.run_nft(nft)
self.update_dns_info(source, dns)
def apply_rules(self, source, rules): def apply_rules(self, source, rules):
if self.is_ip6(source): if self.is_ip6(source):