diff --git a/qubesagent/firewall.py b/qubesagent/firewall.py index 7a5ffc0..1b9c331 100755 --- a/qubesagent/firewall.py +++ b/qubesagent/firewall.py @@ -127,6 +127,26 @@ class FirewallWorker(object): rules.append({'action': policy}) return rules + def resolve_dns(self, fqdn, family): + """ + Resolve the given FQDN via DNS. + :param fqdn: FQDN + :param family: 4 or 6 for IPv4 or IPv6 + :return: see socket.getaddrinfo() + :raises: RuleParseError + """ + try: + addrinfo = socket.getaddrinfo(fqdn, None, + (socket.AF_INET6 if family == 6 else socket.AF_INET)) + except socket.gaierror as e: + raise RuleParseError('Failed to resolve {}: {}'.format( + fqdn, str(e))) + except UnicodeError as e: + raise RuleParseError('Invalid destination {}: {}'.format( + fqdn, str(e))) + return addrinfo + + def update_dns_info(self, source, dns): """ Write resolved DNS addresses back to QubesDB. This can be useful @@ -339,12 +359,7 @@ class IptablesWorker(FirewallWorker): elif 'dst6' in rule: dsthosts = [rule['dst6']] elif 'dsthost' in rule: - try: - addrinfo = socket.getaddrinfo(rule['dsthost'], None, - (socket.AF_INET6 if family == 6 else socket.AF_INET)) - except socket.gaierror as e: - raise RuleParseError('Failed to resolve {}: {}'.format( - rule['dsthost'], str(e))) + addrinfo = self.resolve_dns(rule['dsthost'], family) dsthosts = set(item[4][0] + fullmask for item in addrinfo) ret_dns[rule['dsthost']] = dsthosts else: @@ -652,15 +667,7 @@ class NftablesWorker(FirewallWorker): elif 'dst6' in rule: nft_rule += ' ip6 daddr {}'.format(rule['dst6']) elif 'dsthost' in rule: - try: - addrinfo = socket.getaddrinfo(rule['dsthost'], None, - (socket.AF_INET6 if family == 6 else socket.AF_INET)) - except socket.gaierror as e: - raise RuleParseError('Failed to resolve {}: {}'.format( - rule['dsthost'], str(e))) - except UnicodeError as e: - raise RuleParseError('Invalid destination {}: {}'.format( - rule['dsthost'], str(e))) + addrinfo = self.resolve_dns(rule['dsthost'], family) dsthosts = set(item[4][0] + fullmask for item in addrinfo) nft_rule += ' {} daddr {{ {} }}'.format(ip_match, ', '.join(dsthosts))