firewall.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. # vim: fileencoding=utf-8
  2. #
  3. # The Qubes OS Project, https://www.qubes-os.org/
  4. #
  5. # Copyright (C) 2016
  6. # Marek Marczykowski-Górecki <marmarek@invisiblethingslab.com>
  7. #
  8. # This program is free software; you can redistribute it and/or modify
  9. # it under the terms of the GNU General Public License as published by
  10. # the Free Software Foundation; either version 2 of the License, or
  11. # (at your option) any later version.
  12. #
  13. # This program is distributed in the hope that it will be useful,
  14. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  15. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  16. # GNU General Public License for more details.
  17. #
  18. # You should have received a copy of the GNU General Public License along
  19. # with this program; if not, write to the Free Software Foundation, Inc.,
  20. # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  21. #
  22. import logging
  23. import os
  24. import socket
  25. import subprocess
  26. from distutils import spawn
  27. import daemon
  28. import qubesdb
  29. import sys
  30. import signal
  31. class RuleParseError(Exception):
  32. pass
  33. class RuleApplyError(Exception):
  34. pass
  35. class FirewallWorker(object):
  36. def __init__(self):
  37. self.terminate_requested = False
  38. self.qdb = qubesdb.QubesDB()
  39. self.log = logging.getLogger('qubes.firewall')
  40. self.log.addHandler(logging.StreamHandler(sys.stderr))
  41. def init(self):
  42. """Create appropriate chains/tables"""
  43. raise NotImplementedError
  44. def sd_notify(self, state):
  45. """Send notification to systemd, if available"""
  46. # based on sdnotify python module
  47. if 'NOTIFY_SOCKET' not in os.environ:
  48. return
  49. addr = os.environ['NOTIFY_SOCKET']
  50. if addr[0] == '@':
  51. addr = '\0' + addr[1:]
  52. try:
  53. sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
  54. sock.connect(addr)
  55. sock.sendall(state.encode())
  56. except:
  57. # generally ignore error on systemd notification
  58. pass
  59. def cleanup(self):
  60. """Remove tables/chains - reverse work done by init"""
  61. raise NotImplementedError
  62. def apply_rules(self, source_addr, rules):
  63. """Apply rules in given source address"""
  64. raise NotImplementedError
  65. def run_firewall_dir(self):
  66. """Run scripts dir contents, before user script"""
  67. script_dir_paths = ['/etc/qubes/qubes-firewall.d',
  68. '/rw/config/qubes-firewall.d']
  69. for script_dir_path in script_dir_paths:
  70. if not os.path.isdir(script_dir_path):
  71. continue
  72. for d_script in sorted(os.listdir(script_dir_path)):
  73. d_script_path = os.path.join(script_dir_path, d_script)
  74. if os.path.isfile(d_script_path) and \
  75. os.access(d_script_path, os.X_OK):
  76. subprocess.call([d_script_path])
  77. def run_user_script(self):
  78. """Run user script in /rw/config"""
  79. user_script_path = '/rw/config/qubes-firewall-user-script'
  80. if os.path.isfile(user_script_path) and \
  81. os.access(user_script_path, os.X_OK):
  82. subprocess.call([user_script_path])
  83. def read_rules(self, target):
  84. """Read rules from QubesDB and return them as a list of dicts"""
  85. entries = self.qdb.multiread('/qubes-firewall/{}/'.format(target))
  86. assert isinstance(entries, dict)
  87. # drop full path
  88. entries = dict(((k.split('/')[3], v.decode())
  89. for k, v in entries.items()))
  90. if 'policy' not in entries:
  91. raise RuleParseError('No \'policy\' defined')
  92. policy = entries.pop('policy')
  93. rules = []
  94. for ruleno, rule in sorted(entries.items()):
  95. if len(ruleno) != 4 or not ruleno.isdigit():
  96. raise RuleParseError(
  97. 'Unexpected non-rule found: {}={}'.format(ruleno, rule))
  98. rule_dict = dict(elem.split('=') for elem in rule.split(' '))
  99. if 'action' not in rule_dict:
  100. raise RuleParseError('Rule \'{}\' lack action'.format(rule))
  101. rules.append(rule_dict)
  102. rules.append({'action': policy})
  103. return rules
  104. def list_targets(self):
  105. return set(t.split('/')[2] for t in self.qdb.list('/qubes-firewall/'))
  106. @staticmethod
  107. def is_ip6(addr):
  108. return addr.count(':') > 0
  109. def log_error(self, msg):
  110. self.log.error(msg)
  111. subprocess.call(
  112. ['notify-send', '-t', '3000', msg],
  113. env=os.environ.copy().update({'DISPLAY': ':0'})
  114. )
  115. def handle_addr(self, addr):
  116. try:
  117. rules = self.read_rules(addr)
  118. self.apply_rules(addr, rules)
  119. except RuleParseError as e:
  120. self.log_error(
  121. 'Failed to parse rules for {} ({}), blocking traffic'.format(
  122. addr, str(e)
  123. ))
  124. self.apply_rules(addr, [{'action': 'drop'}])
  125. except RuleApplyError as e:
  126. self.log_error(
  127. 'Failed to apply rules for {} ({}), blocking traffic'.format(
  128. addr, str(e))
  129. )
  130. # retry with fallback rules
  131. try:
  132. self.apply_rules(addr, [{'action': 'drop'}])
  133. except RuleApplyError:
  134. self.log_error(
  135. 'Failed to block traffic for {}'.format(addr))
  136. @staticmethod
  137. def dns_addresses(family=None):
  138. with open('/etc/resolv.conf') as resolv:
  139. for line in resolv.readlines():
  140. line = line.strip()
  141. if line.startswith('nameserver'):
  142. if line.count('.') == 3 and (family or 4) == 4:
  143. yield line.split(' ')[1]
  144. elif line.count(':') and (family or 6) == 6:
  145. yield line.split(' ')[1]
  146. def main(self):
  147. self.terminate_requested = False
  148. self.init()
  149. self.run_firewall_dir()
  150. self.run_user_script()
  151. self.sd_notify('READY=1')
  152. # initial load
  153. for source_addr in self.list_targets():
  154. self.handle_addr(source_addr)
  155. self.qdb.watch('/qubes-firewall/')
  156. try:
  157. for watch_path in iter(self.qdb.read_watch, None):
  158. # ignore writing rules itself - wait for final write at
  159. # source_addr level empty write (/qubes-firewall/SOURCE_ADDR)
  160. if watch_path.count('/') > 2:
  161. continue
  162. source_addr = watch_path.split('/')[2]
  163. self.handle_addr(source_addr)
  164. except OSError: # EINTR
  165. # signal received, don't continue the loop
  166. pass
  167. self.cleanup()
  168. def terminate(self):
  169. self.terminate_requested = True
  170. class IptablesWorker(FirewallWorker):
  171. supported_rule_opts = ['action', 'proto', 'dst4', 'dst6', 'dsthost',
  172. 'dstports', 'specialtarget', 'icmptype']
  173. def __init__(self):
  174. super(IptablesWorker, self).__init__()
  175. self.chains = {
  176. 4: set(),
  177. 6: set(),
  178. }
  179. @staticmethod
  180. def chain_for_addr(addr):
  181. """Generate iptables chain name for given source address address"""
  182. return 'qbs-' + addr.replace('.', '-').replace(':', '-')[-20:]
  183. def run_ipt(self, family, args, **kwargs):
  184. # pylint: disable=no-self-use
  185. if family == 6:
  186. subprocess.check_call(['ip6tables'] + args, **kwargs)
  187. else:
  188. subprocess.check_call(['iptables'] + args, **kwargs)
  189. def run_ipt_restore(self, family, args):
  190. # pylint: disable=no-self-use
  191. if family == 6:
  192. return subprocess.Popen(['ip6tables-restore'] + args,
  193. stdin=subprocess.PIPE,
  194. stdout=subprocess.PIPE,
  195. stderr=subprocess.STDOUT)
  196. else:
  197. return subprocess.Popen(['iptables-restore'] + args,
  198. stdin=subprocess.PIPE,
  199. stdout=subprocess.PIPE,
  200. stderr=subprocess.STDOUT)
  201. def create_chain(self, addr, chain, family):
  202. """
  203. Create iptables chain and hook traffic coming from `addr` to it.
  204. :param addr: source IP from which traffic should be handled by the
  205. chain
  206. :param chain: name of the chain to create
  207. :param family: address family (4 or 6)
  208. :return: None
  209. """
  210. self.run_ipt(family, ['-N', chain])
  211. self.run_ipt(family,
  212. ['-I', 'QBS-FORWARD', '-s', addr, '-j', chain])
  213. self.chains[family].add(chain)
  214. def prepare_rules(self, chain, rules, family):
  215. """
  216. Helper function to translate rules list into input for iptables-restore
  217. :param chain: name of the chain to put rules into
  218. :param rules: list of rules
  219. :param family: address family (4 or 6)
  220. :return: input for iptables-restore
  221. :rtype: str
  222. """
  223. iptables = "*filter\n"
  224. fullmask = '/128' if family == 6 else '/32'
  225. dns = list(addr + fullmask for addr in self.dns_addresses(family))
  226. for rule in rules:
  227. unsupported_opts = set(rule.keys()).difference(
  228. set(self.supported_rule_opts))
  229. if unsupported_opts:
  230. raise RuleParseError(
  231. 'Unsupported rule option(s): {!s}'.format(unsupported_opts))
  232. if 'dst4' in rule and family == 6:
  233. raise RuleParseError('IPv4 rule found for IPv6 address')
  234. if 'dst6' in rule and family == 4:
  235. raise RuleParseError('dst6 rule found for IPv4 address')
  236. if 'proto' in rule:
  237. if rule['proto'] == 'icmp' and family == 6:
  238. protos = ['icmpv6']
  239. else:
  240. protos = [rule['proto']]
  241. else:
  242. protos = None
  243. if 'dst4' in rule:
  244. dsthosts = [rule['dst4']]
  245. elif 'dst6' in rule:
  246. dsthosts = [rule['dst6']]
  247. elif 'dsthost' in rule:
  248. try:
  249. addrinfo = socket.getaddrinfo(rule['dsthost'], None,
  250. (socket.AF_INET6 if family == 6 else socket.AF_INET))
  251. except socket.gaierror as e:
  252. raise RuleParseError('Failed to resolve {}: {}'.format(
  253. rule['dsthost'], str(e)))
  254. dsthosts = set(item[4][0] + fullmask for item in addrinfo)
  255. else:
  256. dsthosts = None
  257. if 'dstports' in rule:
  258. dstports = rule['dstports'].replace('-', ':')
  259. else:
  260. dstports = None
  261. if rule.get('specialtarget', None) == 'dns':
  262. if dstports not in ('53:53', None):
  263. continue
  264. else:
  265. dstports = '53:53'
  266. if not dns:
  267. continue
  268. if protos is not None:
  269. protos = {'tcp', 'udp'}.intersection(protos)
  270. else:
  271. protos = {'tcp', 'udp'}
  272. if dsthosts is not None:
  273. dsthosts = set(dns).intersection(dsthosts)
  274. else:
  275. dsthosts = dns
  276. if 'icmptype' in rule:
  277. icmptype = rule['icmptype']
  278. else:
  279. icmptype = None
  280. # make them iterable
  281. if protos is None:
  282. protos = [None]
  283. if dsthosts is None:
  284. dsthosts = [None]
  285. if rule['action'] == 'accept':
  286. action = 'ACCEPT'
  287. elif rule['action'] == 'drop':
  288. action = 'REJECT --reject-with {}'.format(
  289. 'icmp6-adm-prohibited' if family == 6 else
  290. 'icmp-admin-prohibited')
  291. else:
  292. raise RuleParseError(
  293. 'Invalid rule action {}'.format(rule['action']))
  294. # sorting here is only to ease writing tests
  295. for proto in sorted(protos):
  296. for dsthost in sorted(dsthosts):
  297. ipt_rule = '-A {}'.format(chain)
  298. if dsthost is not None:
  299. ipt_rule += ' -d {}'.format(dsthost)
  300. if proto is not None:
  301. ipt_rule += ' -p {}'.format(proto)
  302. if dstports is not None:
  303. ipt_rule += ' --dport {}'.format(dstports)
  304. if icmptype is not None:
  305. ipt_rule += ' --icmp-type {}'.format(icmptype)
  306. ipt_rule += ' -j {}\n'.format(action)
  307. iptables += ipt_rule
  308. iptables += 'COMMIT\n'
  309. return iptables
  310. def apply_rules_family(self, source, rules, family):
  311. """
  312. Apply rules for given source address.
  313. Handle only rules for given address family (IPv4 or IPv6).
  314. :param source: source address
  315. :param rules: rules list
  316. :param family: address family, either 4 or 6
  317. :return: None
  318. """
  319. chain = self.chain_for_addr(source)
  320. if chain not in self.chains[family]:
  321. self.create_chain(source, chain, family)
  322. iptables = self.prepare_rules(chain, rules, family)
  323. try:
  324. self.run_ipt(family, ['-F', chain])
  325. p = self.run_ipt_restore(family, ['-n'])
  326. (output, _) = p.communicate(iptables.encode())
  327. if p.returncode != 0:
  328. raise RuleApplyError(
  329. 'iptables-restore failed: {}'.format(output))
  330. except subprocess.CalledProcessError as e:
  331. raise RuleApplyError('\'iptables -F {}\' failed: {}'.format(
  332. chain, e.output))
  333. def apply_rules(self, source, rules):
  334. if self.is_ip6(source):
  335. self.apply_rules_family(source, rules, 6)
  336. else:
  337. self.apply_rules_family(source, rules, 4)
  338. def init(self):
  339. # make sure 'QBS_FORWARD' chain exists - should be created before
  340. # starting qubes-firewall
  341. try:
  342. self.run_ipt(4, ['-F', 'QBS-FORWARD'])
  343. self.run_ipt(4,
  344. ['-A', 'QBS-FORWARD', '!', '-i', 'vif+', '-j', 'RETURN'])
  345. self.run_ipt(4, ['-A', 'QBS-FORWARD', '-j', 'DROP'])
  346. self.run_ipt(6, ['-F', 'QBS-FORWARD'])
  347. self.run_ipt(6,
  348. ['-A', 'QBS-FORWARD', '!', '-i', 'vif+', '-j', 'RETURN'])
  349. self.run_ipt(6, ['-A', 'QBS-FORWARD', '-j', 'DROP'])
  350. except subprocess.CalledProcessError:
  351. self.log_error('\'QBS-FORWARD\' chain not found, create it first')
  352. sys.exit(1)
  353. def cleanup(self):
  354. for family in (4, 6):
  355. self.run_ipt(family, ['-F', 'QBS-FORWARD'])
  356. for chain in self.chains[family]:
  357. self.run_ipt(family, ['-F', chain])
  358. self.run_ipt(family, ['-X', chain])
  359. class NftablesWorker(FirewallWorker):
  360. supported_rule_opts = ['action', 'proto', 'dst4', 'dst6', 'dsthost',
  361. 'dstports', 'specialtarget', 'icmptype']
  362. def __init__(self):
  363. super(NftablesWorker, self).__init__()
  364. self.chains = {
  365. 4: set(),
  366. 6: set(),
  367. }
  368. @staticmethod
  369. def chain_for_addr(addr):
  370. """Generate iptables chain name for given source address address"""
  371. return 'qbs-' + addr.replace('.', '-').replace(':', '-')
  372. def run_nft(self, nft_input):
  373. # pylint: disable=no-self-use
  374. p = subprocess.Popen(['nft', '-f', '/dev/stdin'],
  375. stdin=subprocess.PIPE,
  376. stdout=subprocess.PIPE,
  377. stderr=subprocess.STDOUT)
  378. stdout, _ = p.communicate(nft_input.encode())
  379. if p.returncode != 0:
  380. raise RuleApplyError('nft failed: {}'.format(stdout))
  381. def create_chain(self, addr, chain, family):
  382. """
  383. Create iptables chain and hook traffic coming from `addr` to it.
  384. :param addr: source IP from which traffic should be handled by the
  385. chain
  386. :param chain: name of the chain to create
  387. :param family: address family (4 or 6)
  388. :return: None
  389. """
  390. nft_input = (
  391. 'table {family} {table} {{\n'
  392. ' chain {chain} {{\n'
  393. ' }}\n'
  394. ' chain forward {{\n'
  395. ' {family} saddr {ip} jump {chain}\n'
  396. ' }}\n'
  397. '}}\n'.format(
  398. family=("ip6" if family == 6 else "ip"),
  399. table='qubes-firewall',
  400. chain=chain,
  401. ip=addr,
  402. )
  403. )
  404. self.run_nft(nft_input)
  405. self.chains[family].add(chain)
  406. def prepare_rules(self, chain, rules, family):
  407. """
  408. Helper function to translate rules list into input for iptables-restore
  409. :param chain: name of the chain to put rules into
  410. :param rules: list of rules
  411. :param family: address family (4 or 6)
  412. :return: input for iptables-restore
  413. :rtype: str
  414. """
  415. assert family in (4, 6)
  416. nft_rules = []
  417. ip_match = 'ip6' if family == 6 else 'ip'
  418. fullmask = '/128' if family == 6 else '/32'
  419. dns = list(addr + fullmask for addr in self.dns_addresses(family))
  420. for rule in rules:
  421. unsupported_opts = set(rule.keys()).difference(
  422. set(self.supported_rule_opts))
  423. if unsupported_opts:
  424. raise RuleParseError(
  425. 'Unsupported rule option(s): {!s}'.format(unsupported_opts))
  426. if 'dst4' in rule and family == 6:
  427. raise RuleParseError('IPv4 rule found for IPv6 address')
  428. if 'dst6' in rule and family == 4:
  429. raise RuleParseError('dst6 rule found for IPv4 address')
  430. nft_rule = ""
  431. if rule['action'] == 'accept':
  432. action = 'accept'
  433. elif rule['action'] == 'drop':
  434. action = 'reject with icmp{} type admin-prohibited'.format(
  435. 'v6' if family == 6 else '')
  436. else:
  437. raise RuleParseError(
  438. 'Invalid rule action {}'.format(rule['action']))
  439. if 'proto' in rule:
  440. if family == 4:
  441. nft_rule += ' ip protocol {}'.format(rule['proto'])
  442. elif family == 6:
  443. proto = 'icmpv6' if rule['proto'] == 'icmp' \
  444. else rule['proto']
  445. nft_rule += ' ip6 nexthdr {}'.format(proto)
  446. if 'dst4' in rule:
  447. nft_rule += ' ip daddr {}'.format(rule['dst4'])
  448. elif 'dst6' in rule:
  449. nft_rule += ' ip6 daddr {}'.format(rule['dst6'])
  450. elif 'dsthost' in rule:
  451. try:
  452. addrinfo = socket.getaddrinfo(rule['dsthost'], None,
  453. (socket.AF_INET6 if family == 6 else socket.AF_INET))
  454. except socket.gaierror as e:
  455. raise RuleParseError('Failed to resolve {}: {}'.format(
  456. rule['dsthost'], str(e)))
  457. nft_rule += ' {} daddr {{ {} }}'.format(ip_match,
  458. ', '.join(set(item[4][0] + fullmask for item in addrinfo)))
  459. if 'dstports' in rule:
  460. dstports = rule['dstports']
  461. if len(set(dstports.split('-'))) == 1:
  462. dstports = dstports.split('-')[0]
  463. else:
  464. dstports = None
  465. if rule.get('specialtarget', None) == 'dns':
  466. if dstports not in ('53', None):
  467. continue
  468. else:
  469. dstports = '53'
  470. if not dns:
  471. continue
  472. nft_rule += ' {} daddr {{ {} }}'.format(ip_match, ', '.join(
  473. dns))
  474. if 'icmptype' in rule:
  475. if family == 4:
  476. nft_rule += ' icmp type {}'.format(rule['icmptype'])
  477. elif family == 6:
  478. nft_rule += ' icmpv6 type {}'.format(rule['icmptype'])
  479. # now duplicate rules for tcp/udp if needed
  480. # it isn't possible to specify "tcp dport xx || udp dport xx" in
  481. # one rule
  482. if dstports is not None:
  483. if 'proto' not in rule:
  484. nft_rules.append(
  485. nft_rule + ' tcp dport {} {}'.format(
  486. dstports, action))
  487. nft_rules.append(
  488. nft_rule + ' udp dport {} {}'.format(
  489. dstports, action))
  490. else:
  491. nft_rules.append(
  492. nft_rule + ' {} dport {} {}'.format(
  493. rule['proto'], dstports, action))
  494. else:
  495. nft_rules.append(nft_rule + ' ' + action)
  496. return (
  497. 'flush chain {family} {table} {chain}\n'
  498. 'table {family} {table} {{\n'
  499. ' chain {chain} {{\n'
  500. ' {rules}\n'
  501. ' }}\n'
  502. '}}\n'.format(
  503. family=('ip6' if family == 6 else 'ip'),
  504. table='qubes-firewall',
  505. chain=chain,
  506. rules='\n '.join(nft_rules)
  507. ))
  508. def apply_rules_family(self, source, rules, family):
  509. """
  510. Apply rules for given source address.
  511. Handle only rules for given address family (IPv4 or IPv6).
  512. :param source: source address
  513. :param rules: rules list
  514. :param family: address family, either 4 or 6
  515. :return: None
  516. """
  517. chain = self.chain_for_addr(source)
  518. if chain not in self.chains[family]:
  519. self.create_chain(source, chain, family)
  520. self.run_nft(self.prepare_rules(chain, rules, family))
  521. def apply_rules(self, source, rules):
  522. if self.is_ip6(source):
  523. self.apply_rules_family(source, rules, 6)
  524. else:
  525. self.apply_rules_family(source, rules, 4)
  526. def init(self):
  527. # make sure 'QBS_FORWARD' chain exists - should be created before
  528. # starting qubes-firewall
  529. nft_init = (
  530. 'table {family} qubes-firewall {{\n'
  531. ' chain forward {{\n'
  532. ' type filter hook forward priority 0;\n'
  533. ' policy drop;\n'
  534. ' ct state established,related accept\n'
  535. ' meta iifname != "vif*" accept\n'
  536. ' }}\n'
  537. '}}\n'
  538. )
  539. nft_init = ''.join(
  540. nft_init.format(family=family) for family in ('ip', 'ip6'))
  541. self.run_nft(nft_init)
  542. def cleanup(self):
  543. nft_cleanup = (
  544. 'delete table ip qubes-firewall\n'
  545. 'delete table ip6 qubes-firewall\n'
  546. )
  547. self.run_nft(nft_cleanup)
  548. def main():
  549. if spawn.find_executable('nft'):
  550. worker = NftablesWorker()
  551. else:
  552. worker = IptablesWorker()
  553. context = daemon.DaemonContext()
  554. context.stderr = sys.stderr
  555. context.detach_process = False
  556. context.files_preserve = [worker.qdb.watch_fd()]
  557. context.signal_map = {
  558. signal.SIGTERM: lambda _signal, _stack: worker.terminate(),
  559. }
  560. with context:
  561. worker.main()
# Start the firewall daemon only when executed as a script, not on import.
if __name__ == '__main__':
    main()