Browse Source

firewall: refactor to remove side effects from prepare_rules()

3hhh 2 years ago
parent
commit
196014831b
1 changed files with 26 additions and 24 deletions
  1. 26 24
      qubesagent/firewall.py

+ 26 - 24
qubesagent/firewall.py

@@ -127,23 +127,22 @@ class FirewallWorker(object):
         rules.append({'action': policy})
         return rules
 
-    def write_dns_info(self, source, host, hostaddrs):
+    def update_dns_info(self, source, dns):
         """
         Write resolved DNS addresses back to QubesDB. This can be useful
         for the user or DNS applications to pin these DNS addresses to the
         IPs resolved during firewall setup.
 
         :param source: VM IP
-        :param host: hostname
-        :param hostaddrs: set of IP addresses :host: was resolved to
+        :param dns: dict: hostname -> set of IP addresses
         :return: None
         """
-        self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
-
-    def clear_dns_info(self, source):
-        """ Clear all DNS info for the given VM IP."""
+        #clear old info
         self.qdb.rm('/dns/{}/'.format(source))
 
+        for host, hostaddrs in dns.items():
+            self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
+
     def list_targets(self):
         return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
 
@@ -281,16 +280,16 @@ class IptablesWorker(FirewallWorker):
             ['-I', 'QBS-FORWARD', '-s', addr, '-j', chain])
         self.chains[family].add(chain)
 
-    def prepare_rules(self, chain, rules, family, source):
+    def prepare_rules(self, chain, rules, family):
         """
         Helper function to translate rules list into input for iptables-restore
 
         :param chain: name of the chain to put rules into
         :param rules: list of rules
         :param family: address family (4 or 6)
-        :param source: source for which to apply the chain
-        :return: input for iptables-restore
-        :rtype: str
+        :return: tuple: (input for iptables-restore, dict of DNS records resolved
+                        during execution)
+        :rtype: (str, dict)
         """
 
         iptables = "*filter\n"
@@ -299,7 +298,7 @@ class IptablesWorker(FirewallWorker):
 
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
 
-        self.clear_dns_info(source)
+        ret_dns = {}
 
         for rule in rules:
             unsupported_opts = set(rule.keys()).difference(
@@ -332,7 +331,7 @@ class IptablesWorker(FirewallWorker):
                     raise RuleParseError('Failed to resolve {}: {}'.format(
                         rule['dsthost'], str(e)))
                 dsthosts = set(item[4][0] + fullmask for item in addrinfo)
-                self.write_dns_info(source, rule['dsthost'], dsthosts)
+                ret_dns[rule['dsthost']] = dsthosts
             else:
                 dsthosts = None
 
@@ -395,7 +394,7 @@ class IptablesWorker(FirewallWorker):
                     iptables += ipt_rule
 
         iptables += 'COMMIT\n'
-        return iptables
+        return (iptables, ret_dns)
 
     def apply_rules_family(self, source, rules, family):
         """
@@ -412,7 +411,7 @@ class IptablesWorker(FirewallWorker):
         if chain not in self.chains[family]:
             self.create_chain(source, chain, family)
 
-        iptables = self.prepare_rules(chain, rules, family, source)
+        (iptables, dns) = self.prepare_rules(chain, rules, family)
         try:
             self.run_ipt(family, ['-F', chain])
             p = self.run_ipt_restore(family, ['-n'])
@@ -420,6 +419,7 @@ class IptablesWorker(FirewallWorker):
             if p.returncode != 0:
                 raise RuleApplyError(
                     'iptables-restore failed: {}'.format(output))
+            self.update_dns_info(source, dns)
         except subprocess.CalledProcessError as e:
             raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
                 chain, e.output))
@@ -580,16 +580,16 @@ class NftablesWorker(FirewallWorker):
 
         self.run_nft(nft_input)
 
-    def prepare_rules(self, chain, rules, family, source):
+    def prepare_rules(self, chain, rules, family):
         """
-        Helper function to translate rules list into input for iptables-restore
+        Helper function to translate rules list into input for nft
 
         :param chain: name of the chain to put rules into
         :param rules: list of rules
         :param family: address family (4 or 6)
-        :param source: source for which to apply the chain
-        :return: input for iptables-restore
-        :rtype: str
+        :return: tuple: (input for nft, dict of DNS records resolved
+                        during execution)
+        :rtype: (str, dict)
         """
 
         assert family in (4, 6)
@@ -600,7 +600,7 @@ class NftablesWorker(FirewallWorker):
 
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
 
-        self.clear_dns_info(source)
+        ret_dns = {}
 
         for rule in rules:
             unsupported_opts = set(rule.keys()).difference(
@@ -649,7 +649,7 @@ class NftablesWorker(FirewallWorker):
                 dsthosts = set(item[4][0] + fullmask for item in addrinfo)
                 nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
                     ', '.join(dsthosts))
-                self.write_dns_info(source, rule['dsthost'], dsthosts)
+                ret_dns[rule['dsthost']] = dsthosts
 
             if 'dstports' in rule:
                 dstports = rule['dstports']
@@ -703,7 +703,7 @@ class NftablesWorker(FirewallWorker):
                 table='qubes-firewall',
                 chain=chain,
                 rules='\n   '.join(nft_rules)
-            ))
+            ), ret_dns)
 
     def apply_rules_family(self, source, rules, family):
         """
@@ -720,7 +720,9 @@ class NftablesWorker(FirewallWorker):
         if chain not in self.chains[family]:
             self.create_chain(source, chain, family)
 
-        self.run_nft(self.prepare_rules(chain, rules, family, source))
+        (nft, dns) = self.prepare_rules(chain, rules, family)
+        self.run_nft(nft)
+        self.update_dns_info(source, dns)
 
     def apply_rules(self, source, rules):
         if self.is_ip6(source):