Browse Source

Merge remote-tracking branch 'origin/pr/303'

* origin/pr/303:
  firewall: prefer - over _ for QubesDB path
  firewall: put DNS resolving into its own function
  firewall: start watches before initial load
  tests/firewall: added test for /dns/[ip]/[domain] info
  tests/firewall: some code refactoring
  add some checks for QubesDB /qubes-firewall_handled/[ip]
  firewall: adjust tests to the new tuple returned by prepare_rules()
  firewall: mark an IP as handled in /qubes-firewall_handled/[ip] after each handling iteration
  mock qubesdb.rm()
  firewall: refactor to remove side effects from prepare_rules()
  Export DNS information obtained during firewall setup to QubesDB
Marek Marczykowski-Górecki 2 years ago
parent
commit
39a010445e
2 changed files with 137 additions and 43 deletions
  1. 78 28
      qubesagent/firewall.py
  2. 59 15
      qubesagent/test_firewall.py

+ 78 - 28
qubesagent/firewall.py

@@ -127,6 +127,55 @@ class FirewallWorker(object):
         rules.append({'action': policy})
         return rules
 
+    def resolve_dns(self, fqdn, family):
+        """
+        Resolve the given FQDN via DNS.
+        :param fqdn: FQDN
+        :param family: 4 or 6 for IPv4 or IPv6
+        :return: see socket.getaddrinfo()
+        :raises: RuleParseError
+        """
+        try:
+            addrinfo = socket.getaddrinfo(fqdn, None,
+                (socket.AF_INET6 if family == 6 else socket.AF_INET))
+        except socket.gaierror as e:
+            raise RuleParseError('Failed to resolve {}: {}'.format(
+                fqdn, str(e)))
+        except UnicodeError as e:
+            raise RuleParseError('Invalid destination {}: {}'.format(
+                fqdn, str(e)))
+        return addrinfo
+
+
+    def update_dns_info(self, source, dns):
+        """
+        Write resolved DNS addresses back to QubesDB. This can be useful
+        for the user or DNS applications to pin these DNS addresses to the
+        IPs resolved during firewall setup.
+
+        :param source: VM IP
+        :param dns: dict: hostname -> set of IP addresses
+        :return: None
+        """
+        #clear old info
+        self.qdb.rm('/dns/{}/'.format(source))
+
+        for host, hostaddrs in dns.items():
+            self.qdb.write('/dns/{}/{}'.format(source, host), str(hostaddrs))
+
+    def update_handled(self, addr):
+        """
+        Update the QubesDB count of how often the given address was handled.
+        User applications may watch these paths for count increases to remain
+        up to date with QubesDB changes.
+        """
+        cnt = self.qdb.read('/qubes-firewall-handled/{}'.format(addr))
+        try:
+            cnt = int(cnt)
+        except (TypeError, ValueError):
+            cnt = 0
+        self.qdb.write('/qubes-firewall-handled/{}'.format(addr), str(cnt+1))
+
     def list_targets(self):
         return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
 
@@ -163,6 +212,8 @@ class FirewallWorker(object):
                 self.log_error(
                     'Failed to block traffic for {}'.format(addr))
 
+        self.update_handled(addr)
+
     @staticmethod
     def dns_addresses(family=None):
         with open('/etc/resolv.conf') as resolv:
@@ -180,14 +231,14 @@ class FirewallWorker(object):
         self.run_firewall_dir()
         self.run_user_script()
         self.sd_notify('READY=1')
+        self.qdb.watch('/qubes-firewall/')
+        self.qdb.watch('/connected-ips')
+        self.qdb.watch('/connected-ips6')
         # initial load
         for source_addr in self.list_targets():
             self.handle_addr(source_addr)
         self.update_connected_ips(4)
         self.update_connected_ips(6)
-        self.qdb.watch('/qubes-firewall/')
-        self.qdb.watch('/connected-ips')
-        self.qdb.watch('/connected-ips6')
         try:
             for watch_path in iter(self.qdb.read_watch, None):
                 if watch_path == '/connected-ips':
@@ -271,8 +322,9 @@ class IptablesWorker(FirewallWorker):
         :param chain: name of the chain to put rules into
         :param rules: list of rules
         :param family: address family (4 or 6)
-        :return: input for iptables-restore
-        :rtype: str
+        :return: tuple: (input for iptables-restore, dict of DNS records resolved
+                        during execution)
+        :rtype: (str, dict)
         """
 
         iptables = "*filter\n"
@@ -281,6 +333,8 @@ class IptablesWorker(FirewallWorker):
 
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
 
+        ret_dns = {}
+
         for rule in rules:
             unsupported_opts = set(rule.keys()).difference(
                 set(self.supported_rule_opts))
@@ -305,13 +359,9 @@ class IptablesWorker(FirewallWorker):
             elif 'dst6' in rule:
                 dsthosts = [rule['dst6']]
             elif 'dsthost' in rule:
-                try:
-                    addrinfo = socket.getaddrinfo(rule['dsthost'], None,
-                        (socket.AF_INET6 if family == 6 else socket.AF_INET))
-                except socket.gaierror as e:
-                    raise RuleParseError('Failed to resolve {}: {}'.format(
-                        rule['dsthost'], str(e)))
+                addrinfo = self.resolve_dns(rule['dsthost'], family)
                 dsthosts = set(item[4][0] + fullmask for item in addrinfo)
+                ret_dns[rule['dsthost']] = dsthosts
             else:
                 dsthosts = None
 
@@ -374,7 +424,7 @@ class IptablesWorker(FirewallWorker):
                     iptables += ipt_rule
 
         iptables += 'COMMIT\n'
-        return iptables
+        return (iptables, ret_dns)
 
     def apply_rules_family(self, source, rules, family):
         """
@@ -391,7 +441,7 @@ class IptablesWorker(FirewallWorker):
         if chain not in self.chains[family]:
             self.create_chain(source, chain, family)
 
-        iptables = self.prepare_rules(chain, rules, family)
+        (iptables, dns) = self.prepare_rules(chain, rules, family)
         try:
             self.run_ipt(family, ['-F', chain])
             p = self.run_ipt_restore(family, ['-n'])
@@ -399,6 +449,7 @@ class IptablesWorker(FirewallWorker):
             if p.returncode != 0:
                 raise RuleApplyError(
                     'iptables-restore failed: {}'.format(output))
+            self.update_dns_info(source, dns)
         except subprocess.CalledProcessError as e:
             raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
                 chain, e.output))
@@ -561,13 +612,14 @@ class NftablesWorker(FirewallWorker):
 
     def prepare_rules(self, chain, rules, family):
         """
-        Helper function to translate rules list into input for iptables-restore
+        Helper function to translate rules list into input for nft
 
         :param chain: name of the chain to put rules into
         :param rules: list of rules
         :param family: address family (4 or 6)
-        :return: input for iptables-restore
-        :rtype: str
+        :return: tuple: (input for nft, dict of DNS records resolved
+                        during execution)
+        :rtype: (str, dict)
         """
 
         assert family in (4, 6)
@@ -578,6 +630,8 @@ class NftablesWorker(FirewallWorker):
 
         dns = list(addr + fullmask for addr in self.dns_addresses(family))
 
+        ret_dns = {}
+
         for rule in rules:
             unsupported_opts = set(rule.keys()).difference(
                 set(self.supported_rule_opts))
@@ -613,17 +667,11 @@ class NftablesWorker(FirewallWorker):
             elif 'dst6' in rule:
                 nft_rule += ' ip6 daddr {}'.format(rule['dst6'])
             elif 'dsthost' in rule:
-                try:
-                    addrinfo = socket.getaddrinfo(rule['dsthost'], None,
-                        (socket.AF_INET6 if family == 6 else socket.AF_INET))
-                except socket.gaierror as e:
-                    raise RuleParseError('Failed to resolve {}: {}'.format(
-                        rule['dsthost'], str(e)))
-                except UnicodeError as e:
-                    raise RuleParseError('Invalid destination {}: {}'.format(
-                        rule['dsthost'], str(e)))
+                addrinfo = self.resolve_dns(rule['dsthost'], family)
+                dsthosts = set(item[4][0] + fullmask for item in addrinfo)
                 nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
-                    ', '.join(set(item[4][0] + fullmask for item in addrinfo)))
+                    ', '.join(dsthosts))
+                ret_dns[rule['dsthost']] = dsthosts
 
             if 'dstports' in rule:
                 dstports = rule['dstports']
@@ -677,7 +725,7 @@ class NftablesWorker(FirewallWorker):
                 table='qubes-firewall',
                 chain=chain,
                 rules='\n   '.join(nft_rules)
-            ))
+            ), ret_dns)
 
     def apply_rules_family(self, source, rules, family):
         """
@@ -694,7 +742,9 @@ class NftablesWorker(FirewallWorker):
         if chain not in self.chains[family]:
             self.create_chain(source, chain, family)
 
-        self.run_nft(self.prepare_rules(chain, rules, family))
+        (nft, dns) = self.prepare_rules(chain, rules, family)
+        self.run_nft(nft)
+        self.update_dns_info(source, dns)
 
     def apply_rules(self, source, rules):
         if self.is_ip6(source):

+ 59 - 15
qubesagent/test_firewall.py

@@ -1,5 +1,6 @@
 import logging
 import operator
+import re
 from unittest import TestCase
 from unittest.mock import patch
 
@@ -29,6 +30,17 @@ class DummyQubesDB(object):
         except KeyError:
             return None
 
+    def rm(self, path):
+        if path.endswith('/'):
+            for key in list(self.entries):
+                if key.startswith(path):
+                    self.entries.pop(key)
+        else:
+            self.entries.pop(path)
+
+    def write(self, path, val):
+        self.entries[path] = val
+
     def multiread(self, prefix):
         result = {}
         for key, value in self.entries.items():
@@ -154,8 +166,31 @@ class NftablesWorker(qubesagent.firewall.NftablesWorker):
         else:
             return ['2001::1', '2001::2']
 
+class WorkerCommon(object):
+    def assertPrepareRulesDnsRet(self, dns_ret, expected_domain, family):
+        self.assertEqual(dns_ret.keys(), {expected_domain})
+        self.assertIsInstance(dns_ret[expected_domain], set)
+        if family == 4:
+            self.assertIsNotNone(re.match('^\d+\.\d+\.\d+\.\d+/32$',
+                                dns_ret[expected_domain].pop()))
+        elif family == 6:
+            self.assertIsNotNone(re.match('^[0-9a-f:]+/\d+$',
+                                dns_ret[expected_domain].pop()))
+        else:
+            raise ValueError()
+
+    def test_701_dns_info(self):
+        rules = [
+            {'action': 'accept', 'proto': 'tcp',
+                'dstports': '80-80', 'dsthost': 'ripe.net'},
+            {'action': 'drop'},
+        ]
+        self.obj.apply_rules('10.137.0.1', rules)
+        self.assertIsNotNone(self.obj.qdb.read('/dns/10.137.0.1/ripe.net'))
+        self.obj.apply_rules('10.137.0.1', [{'action': 'drop'}])
+        self.assertIsNone(self.obj.qdb.read('/dns/10.137.0.1/ripe.net'))
 
-class TestIptablesWorker(TestCase):
+class TestIptablesWorker(TestCase, WorkerCommon):
     def setUp(self):
         super(TestIptablesWorker, self).setUp()
         self.obj = IptablesWorker()
@@ -212,8 +247,9 @@ class TestIptablesWorker(TestCase):
             "--reject-with icmp-admin-prohibited\n"
             "COMMIT\n"
         )
-        self.assertEqual(self.obj.prepare_rules('chain', rules, 4),
-            expected_iptables)
+        ret = self.obj.prepare_rules('chain', rules, 4)
+        self.assertEqual(ret[0], expected_iptables)
+        self.assertPrepareRulesDnsRet(ret[1], 'yum.qubes-os.org', 4)
         with self.assertRaises(qubesagent.firewall.RuleParseError):
             self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4)
         with self.assertRaises(qubesagent.firewall.RuleParseError):
@@ -250,8 +286,9 @@ class TestIptablesWorker(TestCase):
             "--reject-with icmp6-adm-prohibited\n"
             "COMMIT\n"
         )
-        self.assertEqual(self.obj.prepare_rules('chain', rules, 6),
-            expected_iptables)
+        ret = self.obj.prepare_rules('chain', rules, 6)
+        self.assertEqual(ret[0], expected_iptables)
+        self.assertPrepareRulesDnsRet(ret[1], 'ripe.net', 6)
 
     def test_004_apply_rules4(self):
         rules = [{'action': 'accept'}]
@@ -263,7 +300,7 @@ class TestIptablesWorker(TestCase):
                 ['-I', 'QBS-FORWARD', '-s', '10.137.0.1', '-j', chain],
                 ['-F', chain]])
         self.assertEqual(self.obj.loaded_iptables[4],
-            self.obj.prepare_rules(chain, rules, 4))
+            self.obj.prepare_rules(chain, rules, 4)[0])
         self.assertEqual(self.obj.called_commands[6], [])
         self.assertIsNone(self.obj.loaded_iptables[6])
 
@@ -277,7 +314,7 @@ class TestIptablesWorker(TestCase):
                 ['-I', 'QBS-FORWARD', '-s', '2000::a', '-j', chain],
                 ['-F', chain]])
         self.assertEqual(self.obj.loaded_iptables[6],
-            self.obj.prepare_rules(chain, rules, 6))
+            self.obj.prepare_rules(chain, rules, 6)[0])
         self.assertEqual(self.obj.called_commands[4], [])
         self.assertIsNone(self.obj.loaded_iptables[4])
 
@@ -372,8 +409,7 @@ class TestIptablesWorker(TestCase):
             ['-t', 'mangle', '-F', 'QBS-POSTROUTING'],
         ])
 
-
-class TestNftablesWorker(TestCase):
+class TestNftablesWorker(TestCase, WorkerCommon):
     def setUp(self):
         super(TestNftablesWorker, self).setUp()
         self.obj = NftablesWorker()
@@ -440,8 +476,9 @@ class TestNftablesWorker(TestCase):
             '  }\n'
             '}\n'
         )
-        self.assertEqual(self.obj.prepare_rules('chain', rules, 4),
-            expected_nft)
+        ret = self.obj.prepare_rules('chain', rules, 4)
+        self.assertEqual(ret[0], expected_nft)
+        self.assertPrepareRulesDnsRet(ret[1], 'yum.qubes-os.org', 4)
         with self.assertRaises(qubesagent.firewall.RuleParseError):
             self.obj.prepare_rules('chain', [{'unknown': 'xxx'}], 4)
         with self.assertRaises(qubesagent.firewall.RuleParseError):
@@ -477,8 +514,9 @@ class TestNftablesWorker(TestCase):
             '  }\n'
             '}\n'
         )
-        self.assertEqual(self.obj.prepare_rules('chain', rules, 6),
-            expected_nft)
+        ret = self.obj.prepare_rules('chain', rules, 6)
+        self.assertEqual(ret[0], expected_nft)
+        self.assertPrepareRulesDnsRet(ret[1], 'ripe.net', 6)
 
     def test_004_apply_rules4(self):
         rules = [{'action': 'accept'}]
@@ -486,7 +524,7 @@ class TestNftablesWorker(TestCase):
         self.obj.apply_rules('10.137.0.1', rules)
         self.assertEqual(self.obj.loaded_rules,
             [self.expected_create_chain('ip', '10.137.0.1', chain),
-             self.obj.prepare_rules(chain, rules, 4),
+             self.obj.prepare_rules(chain, rules, 4)[0],
              ])
 
     def test_005_apply_rules6(self):
@@ -495,7 +533,7 @@ class TestNftablesWorker(TestCase):
         self.obj.apply_rules('2000::a', rules)
         self.assertEqual(self.obj.loaded_rules,
             [self.expected_create_chain('ip6', '2000::a', chain),
-             self.obj.prepare_rules(chain, rules, 6),
+             self.obj.prepare_rules(chain, rules, 6)[0],
              ])
 
     def test_006_init(self):
@@ -647,11 +685,17 @@ class TestFirewallWorker(TestCase):
     def test_handle_addr(self):
         self.obj.handle_addr('10.137.0.2')
         self.assertEqual(self.obj.rules['10.137.0.2'], [{'action': 'accept'}])
+        self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.2'], '1')
+        self.obj.handle_addr('10.137.0.2')
+        self.assertEqual(self.obj.rules['10.137.0.2'], [{'action': 'accept'}])
+        self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.2'], '2')
         # fallback to block all
         self.obj.handle_addr('10.137.0.3')
         self.assertEqual(self.obj.rules['10.137.0.3'], [{'action': 'drop'}])
+        self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.3'], '1')
         self.obj.handle_addr('10.137.0.4')
         self.assertEqual(self.obj.rules['10.137.0.4'], [{'action': 'drop'}])
+        self.assertEqual(self.obj.qdb.entries['/qubes-firewall-handled/10.137.0.4'], '1')
 
     @patch('os.path.isfile')
     @patch('os.access')