Browse Source

firewall: refactor to remove side effects from prepare_rules()

3hhh 3 years ago
parent
commit
196014831b
1 changed files with 26 additions and 24 deletions
  1. 26 24
      qubesagent/firewall.py

+ 26 - 24
qubesagent/firewall.py

@@ -127,23 +127,22 @@ class FirewallWorker(object):
         rules.append({'action': policy})
         rules.append({'action': policy})
         return rules
         return rules
 
 
-    def write_dns_info(self, source, host, hostaddrs):
+    def update_dns_info(self, source, dns):
         """
         """
         Write resolved DNS addresses back to QubesDB. This can be useful
         Write resolved DNS addresses back to QubesDB. This can be useful
         for the user or DNS applications to pin these DNS addresses to the
         for the user or DNS applications to pin these DNS addresses to the
         IPs resolved during firewall setup.
         IPs resolved during firewall setup.
 
 
         :param source: VM IP
         :param source: VM IP
-        :param host: hostname
-        :param hostaddrs: set of IP addresses :host: was resolved to
+        :param dns: dict: hostname -> set of IP addresses
         :return: None
         :return: None
         """
         """
-        self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
-
-    def clear_dns_info(self, source):
-        """ Clear all DNS info for the given VM IP."""
+        #clear old info
         self.qdb.rm('/dns/{}/'.format(source))
         self.qdb.rm('/dns/{}/'.format(source))
 
 
+        for host, hostaddrs in dns.items():
+            self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
+
     def list_targets(self):
     def list_targets(self):
         return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
         return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
 
 
@@ -281,16 +280,16 @@ class IptablesWorker(FirewallWorker):
             ['-I', 'QBS-FORWARD', '-s', addr, '-j', chain])
             ['-I', 'QBS-FORWARD', '-s', addr, '-j', chain])
         self.chains[family].add(chain)
         self.chains[family].add(chain)
 
 
-    def prepare_rules(self, chain, rules, family, source):
+    def prepare_rules(self, chain, rules, family):
         """
         """
         Helper function to translate rules list into input for iptables-restore
         Helper function to translate rules list into input for iptables-restore
 
 
         :param chain: name of the chain to put rules into
         :param chain: name of the chain to put rules into
         :param rules: list of rules
         :param rules: list of rules
         :param family: address family (4 or 6)
         :param family: address family (4 or 6)
-        :param source: source for which to apply the chain
-        :return: input for iptables-restore
-        :rtype: str
+        :return: tuple: (input for iptables-restore, dict of DNS records resolved
+                        during execution)
+        :rtype: (str, dict)
         """
         """
 
 
         iptables = "*filter\n"
         iptables = "*filter\n"
@@ -299,7 +298,7 @@ class IptablesWorker(FirewallWorker):
 
 
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
 
 
-        self.clear_dns_info(source)
+        ret_dns = {}
 
 
         for rule in rules:
         for rule in rules:
             unsupported_opts = set(rule.keys()).difference(
             unsupported_opts = set(rule.keys()).difference(
@@ -332,7 +331,7 @@ class IptablesWorker(FirewallWorker):
                     raise RuleParseError('Failed to resolve {}: {}'.format(
                     raise RuleParseError('Failed to resolve {}: {}'.format(
                         rule['dsthost'], str(e)))
                         rule['dsthost'], str(e)))
                 dsthosts = set(item[4][0] + fullmask for item in addrinfo)
                 dsthosts = set(item[4][0] + fullmask for item in addrinfo)
-                self.write_dns_info(source, rule['dsthost'], dsthosts)
+                ret_dns[rule['dsthost']] = dsthosts
             else:
             else:
                 dsthosts = None
                 dsthosts = None
 
 
@@ -395,7 +394,7 @@ class IptablesWorker(FirewallWorker):
                     iptables += ipt_rule
                     iptables += ipt_rule
 
 
         iptables += 'COMMIT\n'
         iptables += 'COMMIT\n'
-        return iptables
+        return (iptables, ret_dns)
 
 
     def apply_rules_family(self, source, rules, family):
     def apply_rules_family(self, source, rules, family):
         """
         """
@@ -412,7 +411,7 @@ class IptablesWorker(FirewallWorker):
         if chain not in self.chains[family]:
         if chain not in self.chains[family]:
             self.create_chain(source, chain, family)
             self.create_chain(source, chain, family)
 
 
-        iptables = self.prepare_rules(chain, rules, family, source)
+        (iptables, dns) = self.prepare_rules(chain, rules, family)
         try:
         try:
             self.run_ipt(family, ['-F', chain])
             self.run_ipt(family, ['-F', chain])
             p = self.run_ipt_restore(family, ['-n'])
             p = self.run_ipt_restore(family, ['-n'])
@@ -420,6 +419,7 @@ class IptablesWorker(FirewallWorker):
             if p.returncode != 0:
             if p.returncode != 0:
                 raise RuleApplyError(
                 raise RuleApplyError(
                     'iptables-restore failed: {}'.format(output))
                     'iptables-restore failed: {}'.format(output))
+            self.update_dns_info(source, dns)
         except subprocess.CalledProcessError as e:
         except subprocess.CalledProcessError as e:
             raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
             raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
                 chain, e.output))
                 chain, e.output))
@@ -580,16 +580,16 @@ class NftablesWorker(FirewallWorker):
 
 
         self.run_nft(nft_input)
         self.run_nft(nft_input)
 
 
-    def prepare_rules(self, chain, rules, family, source):
+    def prepare_rules(self, chain, rules, family):
         """
         """
-        Helper function to translate rules list into input for iptables-restore
+        Helper function to translate rules list into input for nft
 
 
         :param chain: name of the chain to put rules into
         :param chain: name of the chain to put rules into
         :param rules: list of rules
         :param rules: list of rules
         :param family: address family (4 or 6)
         :param family: address family (4 or 6)
-        :param source: source for which to apply the chain
-        :return: input for iptables-restore
-        :rtype: str
+        :return: tuple: (input for nft, dict of DNS records resolved
+                        during execution)
+        :rtype: (str, dict)
         """
         """
 
 
         assert family in (4, 6)
         assert family in (4, 6)
@@ -600,7 +600,7 @@ class NftablesWorker(FirewallWorker):
 
 
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
 
 
-        self.clear_dns_info(source)
+        ret_dns = {}
 
 
         for rule in rules:
         for rule in rules:
             unsupported_opts = set(rule.keys()).difference(
             unsupported_opts = set(rule.keys()).difference(
@@ -649,7 +649,7 @@ class NftablesWorker(FirewallWorker):
                 dsthosts = set(item[4][0] + fullmask for item in addrinfo)
                 dsthosts = set(item[4][0] + fullmask for item in addrinfo)
                 nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
                 nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
                     ', '.join(dsthosts))
                     ', '.join(dsthosts))
-                self.write_dns_info(source, rule['dsthost'], dsthosts)
+                ret_dns[rule['dsthost']] = dsthosts
 
 
             if 'dstports' in rule:
             if 'dstports' in rule:
                 dstports = rule['dstports']
                 dstports = rule['dstports']
@@ -703,7 +703,7 @@ class NftablesWorker(FirewallWorker):
                 table='qubes-firewall',
                 table='qubes-firewall',
                 chain=chain,
                 chain=chain,
                 rules='\n   '.join(nft_rules)
                 rules='\n   '.join(nft_rules)
-            ))
+            ), ret_dns)
 
 
     def apply_rules_family(self, source, rules, family):
     def apply_rules_family(self, source, rules, family):
         """
         """
@@ -720,7 +720,9 @@ class NftablesWorker(FirewallWorker):
         if chain not in self.chains[family]:
         if chain not in self.chains[family]:
             self.create_chain(source, chain, family)
             self.create_chain(source, chain, family)
 
 
-        self.run_nft(self.prepare_rules(chain, rules, family, source))
+        (nft, dns) = self.prepare_rules(chain, rules, family)
+        self.run_nft(nft)
+        self.update_dns_info(source, dns)
 
 
     def apply_rules(self, source, rules):
     def apply_rules(self, source, rules):
         if self.is_ip6(source):
         if self.is_ip6(source):