From 0d49ba27103a0612a898e6fb9a04202cebf58905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Marczykowski-G=C3=B3recki?= Date: Sun, 30 Apr 2017 22:56:22 +0200 Subject: [PATCH] firewall: add firewall API To keep API compatibility with core-admin, most data structures are copied from there. --- qubesmgmt/firewall.py | 448 ++++++++++++++++++++++++++++++++++ qubesmgmt/tests/firewall.py | 467 ++++++++++++++++++++++++++++++++++++ qubesmgmt/vm/__init__.py | 4 + 3 files changed, 919 insertions(+) create mode 100644 qubesmgmt/firewall.py create mode 100644 qubesmgmt/tests/firewall.py diff --git a/qubesmgmt/firewall.py b/qubesmgmt/firewall.py new file mode 100644 index 0000000..06045cf --- /dev/null +++ b/qubesmgmt/firewall.py @@ -0,0 +1,448 @@ +# -*- encoding: utf8 -*- +# pylint: disable=too-few-public-methods +# +# The Qubes OS Project, http://www.qubes-os.org +# +# Copyright (C) 2017 Marek Marczykowski-Górecki +# +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation; either version 2.1 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with this program; if not, see . + +'''Firewall configuration interface''' + +import datetime +import socket + +class RuleOption(object): + '''Base class for a single rule element''' + def __init__(self, value): + self._value = str(value) + + @property + def rule(self): + '''API representation of this rule element''' + raise NotImplementedError + + def __str__(self): + return self._value + + def __eq__(self, other): + return str(self) == other + + +# noinspection PyAbstractClass +class RuleChoice(RuleOption): + '''Base class for multiple-choices rule elements''' + # pylint: disable=abstract-method + def __init__(self, value): + super(RuleChoice, self).__init__(value) + self.allowed_values = \ + [v for k, v in self.__class__.__dict__.items() + if not k.startswith('__') and isinstance(v, str) and + not v.startswith('__')] + if value not in self.allowed_values: + raise ValueError(value) + + +class Action(RuleChoice): + '''Rule action''' + accept = 'accept' + drop = 'drop' + + @property + def rule(self): + '''API representation of this rule element''' + return 'action=' + str(self) + + +class Proto(RuleChoice): + '''Protocol name''' + tcp = 'tcp' + udp = 'udp' + icmp = 'icmp' + + @property + def rule(self): + '''API representation of this rule element''' + return 'proto=' + str(self) + + +class DstHost(RuleOption): + '''Represent host/network address: either IPv4, IPv6, or DNS name''' + def __init__(self, value, prefixlen=None): + # TODO: in python >= 3.3 ipaddress module could be used + if value.count('/') > 1: + raise ValueError('Too many /: ' + value) + elif not value.count('/'): + # add prefix length to bare IP addresses + try: + socket.inet_pton(socket.AF_INET6, value) + self.prefixlen = prefixlen or 128 + if self.prefixlen < 0 or self.prefixlen > 128: + raise ValueError( + 'netmask for IPv6 must be between 0 and 128') + value += '/' + str(self.prefixlen) + self.type = 'dst6' + except socket.error: + try: + socket.inet_pton(socket.AF_INET, value) + if value.count('.') != 3: + raise ValueError( + 'Invalid number of dots in IPv4 address') + self.prefixlen = prefixlen or 32 + if self.prefixlen < 0 or self.prefixlen > 32: + raise ValueError( + 'netmask for IPv4 must be between 0 and 32') + value += '/' + str(self.prefixlen) + self.type = 'dst4' + except socket.error: + self.type = 'dsthost' + self.prefixlen = 0 + else: + host, prefixlen = value.split('/', 1) + prefixlen = int(prefixlen) + if prefixlen < 0: + raise ValueError('netmask must be non-negative') + self.prefixlen = prefixlen + try: + socket.inet_pton(socket.AF_INET6, host) + if prefixlen > 128: + raise ValueError('netmask for IPv6 must be <= 128') + self.type = 'dst6' + except socket.error: + try: + socket.inet_pton(socket.AF_INET, host) + if prefixlen > 32: + raise ValueError('netmask for IPv4 must be <= 32') + self.type = 'dst4' + if host.count('.') != 3: + raise ValueError( + 'Invalid number of dots in IPv4 address') + except socket.error: + raise ValueError('Invalid IP address: ' + host) + + super(DstHost, self).__init__(value) + + @property + def rule(self): + '''API representation of this rule element''' + return self.type + '=' + str(self) + + +class DstPorts(RuleOption): + '''Destination port(s), for TCP/UDP only''' + def __init__(self, value): + if isinstance(value, int): + value = str(value) + if value.count('-') == 1: + self.range = [int(x) for x in value.split('-', 1)] + elif not value.count('-'): + self.range = [int(value), int(value)] + else: + raise ValueError(value) + if any(port < 0 or port > 65536 for port in self.range): + raise ValueError('Ports out of range') + if self.range[0] > self.range[1]: + raise ValueError('Invalid port range') + super(DstPorts, self).__init__( + str(self.range[0]) if self.range[0] == self.range[1] + else '{!s}-{!s}'.format(*self.range)) + + @property + def rule(self): + '''API representation of this rule element''' + return 'dstports=' + '{!s}-{!s}'.format(*self.range) + + +class IcmpType(RuleOption): + '''ICMP packet type''' + def __init__(self, value): + super(IcmpType, self).__init__(value) + value = int(value) + if value < 0 or value > 255: + raise ValueError('ICMP type out of range') + + @property + def rule(self): + '''API representation of this rule element''' + return 'icmptype=' + str(self) + + +class SpecialTarget(RuleChoice): + '''Special destination''' + dns = 'dns' + + @property + def rule(self): + '''API representation of this rule element''' + return 'specialtarget=' + str(self) + + +class Expire(RuleOption): + '''Rule expire time''' + def __init__(self, value): + super(Expire, self).__init__(value) + self.datetime = datetime.datetime.utcfromtimestamp(int(value)) + + @property + def rule(self): + '''API representation of this rule element''' + return 'expire=' + str(self) + + @property + def expired(self): + '''Have this rule expired already?''' + return self.datetime < datetime.datetime.utcnow() + + +class Comment(RuleOption): + '''User comment''' + @property + def rule(self): + '''API representation of this rule element''' + return 'comment=' + str(self) + + +class Rule(object): + '''A single firewall rule''' + + def __init__(self, rule, **kwargs): + '''Single firewall rule + + :param xml: XML element describing rule, or None + :param kwargs: rule elements + ''' + self._action = None + self._proto = None + self._dsthost = None + self._dstports = None + self._icmptype = None + self._specialtarget = None + self._expire = None + self._comment = None + + rule_dict = {} + if rule is not None: + rule_opts, _, comment = rule.partition('comment=') + + rule_dict = dict(rule_opt.split('=', 1) for rule_opt in + rule_opts.split(' ') if rule_opt) + if comment: + rule_dict['comment'] = comment + rule_dict.update(kwargs) + + rule_elements = ('action', 'proto', 'dsthost', 'dst4', 'dst6', + 'specialtarget', 'dstports', 'icmptype', 'expire', 'comment') + for rule_opt in rule_elements: + value = rule_dict.pop(rule_opt, None) + if value is None: + continue + if rule_opt in ('dst4', 'dst6'): + rule_opt = 'dsthost' + setattr(self, rule_opt, value) + + if rule_dict: + raise ValueError('Unknown rule elements: {!r}'.format( + rule_dict)) + + if self.action is None: + raise ValueError('missing action=') + + @property + def action(self): + '''rule action''' + return self._action + + @action.setter + def action(self, value): + if not isinstance(value, Action): + value = Action(value) + self._action = value + + @property + def proto(self): + '''protocol to match''' + return self._proto + + @proto.setter + def proto(self, value): + if value is not None and not isinstance(value, Proto): + value = Proto(value) + if value not in ('tcp', 'udp'): + self.dstports = None + if value not in ('icmp',): + self.icmptype = None + self._proto = value + + @property + def dsthost(self): + '''destination host/network''' + return self._dsthost + + @dsthost.setter + def dsthost(self, value): + if value is not None and not isinstance(value, DstHost): + value = DstHost(value) + self._dsthost = value + + @property + def dstports(self): + ''''Destination port(s) (for \'tcp\' and \'udp\' protocol only)''' + return self._dstports + + @dstports.setter + def dstports(self, value): + if value is not None: + if self.proto not in ('tcp', 'udp'): + raise ValueError( + 'dstports valid only for \'tcp\' and \'udp\' protocols') + if not isinstance(value, DstPorts): + value = DstPorts(value) + self._dstports = value + + @property + def icmptype(self): + '''ICMP packet type (for \'icmp\' protocol only)''' + return self._icmptype + + @icmptype.setter + def icmptype(self, value): + if value is not None: + if self.proto not in ('icmp',): + raise ValueError('icmptype valid only for \'icmp\' protocol') + if not isinstance(value, IcmpType): + value = IcmpType(value) + self._icmptype = value + + @property + def specialtarget(self): + '''Special target, for now only \'dns\' supported''' + return self._specialtarget + + @specialtarget.setter + def specialtarget(self, value): + if not isinstance(value, SpecialTarget): + value = SpecialTarget(value) + self._specialtarget = value + + @property + def expire(self): + '''Timestamp (UNIX epoch) on which this rule expire''' + return self._expire + + @expire.setter + def expire(self, value): + if not isinstance(value, Expire): + value = Expire(value) + self._expire = value + + @property + def comment(self): + '''User comment''' + return self._comment + + @comment.setter + def comment(self, value): + if not isinstance(value, Comment): + value = Comment(value) + self._comment = value + + @property + def rule(self): + '''API representation of this rule''' + values = [] + # comment must be the last one + for prop in ('action', 'proto', 'dsthost', 'dstports', 'icmptype', + 'specialtarget', 'expire', 'comment'): + value = getattr(self, prop) + if value is None: + continue + if value.rule is None: + continue + values.append(value.rule) + return ' '.join(values) + + def __eq__(self, other): + if isinstance(other, Rule): + return self.rule == other.rule + if isinstance(other, str): + return self.rule == str + return NotImplemented + + def __repr__(self): + return 'Rule(\'{}\')'.format(self.rule) + + +class Firewall(object): + '''Firewal manager for a VM''' + def __init__(self, vm): + self.vm = vm + self._rules = [] + self._policy = None + self._loaded = False + + def load_rules(self): + '''Force (re-)loading firewall rules''' + rules_str = self.vm.qubesd_call(None, 'mgmt.vm.firewall.Get') + rules = [] + for rule_str in rules_str.decode().splitlines(): + rules.append(Rule(rule_str)) + self._rules = rules + self._loaded = True + + @property + def rules(self): + '''Firewall rules + + You can either copy them, edit and then assign new rules list to this + property, or edit in-place and call :py:meth:`save_rules`. + Once rules are loaded, they are cached. To reload rules, + call :py:meth:`load_rules`. + ''' + if not self._loaded: + self.load_rules() + return self._rules + + @rules.setter + def rules(self, value): + self.save_rules(value) + self._rules = value + + def save_rules(self, rules=None): + '''Save firewall rules. Needs to be called after in-place editing + :py:attr:`rules`. + ''' + if rules is None: + rules = self._rules + self.vm.qubesd_call(None, 'mgmt.vm.firewall.Set', + payload=(''.join('{}\n'.format(rule.rule) + for rule in rules)).encode('ascii')) + + @property + def policy(self): + '''Default action to take if no rule matches''' + policy_str = self.vm.qubesd_call(None, 'mgmt.vm.firewall.GetPolicy') + return Action(policy_str.decode()) + + @policy.setter + def policy(self, value): + self.vm.qubesd_call(None, 'mgmt.vm.firewall.SetPolicy', payload=str( + value).encode('ascii')) + + def reload(self): + '''Force reload the same firewall rules. + + Can be used for example to force again names resolution. + ''' + self.vm.qubesd_call(None, 'mgmt.vm.firewall.Reload') diff --git a/qubesmgmt/tests/firewall.py b/qubesmgmt/tests/firewall.py new file mode 100644 index 0000000..ddb5f68 --- /dev/null +++ b/qubesmgmt/tests/firewall.py @@ -0,0 +1,467 @@ +# -*- encoding: utf8 -*- +# +# The Qubes OS Project, http://www.qubes-os.org +# +# Copyright (C) 2017 Marek Marczykowski-Górecki +# +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation; either version 2.1 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with this program; if not, see . + + +'''Tests for firewall API. This is mostly copy from core-admin''' +import datetime +import unittest +import qubesmgmt.firewall +import qubesmgmt.tests + + +class TestOption(qubesmgmt.firewall.RuleChoice): + opt1 = 'opt1' + opt2 = 'opt2' + another = 'another' + + +# noinspection PyPep8Naming +class TC_00_RuleChoice(qubesmgmt.tests.QubesTestCase): + def test_000_accept_allowed(self): + with self.assertNotRaises(ValueError): + TestOption('opt1') + TestOption('opt2') + TestOption('another') + + def test_001_value_list(self): + instance = TestOption('opt1') + self.assertEqual( + set(instance.allowed_values), {'opt1', 'opt2', 'another'}) + + def test_010_reject_others(self): + self.assertRaises(ValueError, lambda: TestOption('invalid')) + + +class TC_01_Action(qubesmgmt.tests.QubesTestCase): + def test_000_allowed_values(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.Action('accept') + self.assertEqual( + set(instance.allowed_values), {'accept', 'drop'}) + + def test_001_rule(self): + instance = qubesmgmt.firewall.Action('accept') + self.assertEqual(instance.rule, 'action=accept') + + +# noinspection PyPep8Naming +class TC_02_Proto(qubesmgmt.tests.QubesTestCase): + def test_000_allowed_values(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.Proto('tcp') + self.assertEqual( + set(instance.allowed_values), {'tcp', 'udp', 'icmp'}) + + def test_001_rule(self): + instance = qubesmgmt.firewall.Proto('tcp') + self.assertEqual(instance.rule, 'proto=tcp') + + +# noinspection PyPep8Naming +class TC_02_DstHost(qubesmgmt.tests.QubesTestCase): + def test_000_hostname(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstHost('qubes-os.org') + self.assertEqual(instance.type, 'dsthost') + + def test_001_ipv4(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstHost('127.0.0.1') + self.assertEqual(instance.type, 'dst4') + self.assertEqual(instance.prefixlen, 32) + self.assertEqual(str(instance), '127.0.0.1/32') + self.assertEqual(instance.rule, 'dst4=127.0.0.1/32') + + def test_002_ipv4_prefixlen(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstHost('127.0.0.0', 8) + self.assertEqual(instance.type, 'dst4') + self.assertEqual(instance.prefixlen, 8) + self.assertEqual(str(instance), '127.0.0.0/8') + self.assertEqual(instance.rule, 'dst4=127.0.0.0/8') + + def test_003_ipv4_parse_prefixlen(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstHost('127.0.0.0/8') + self.assertEqual(instance.type, 'dst4') + self.assertEqual(instance.prefixlen, 8) + self.assertEqual(str(instance), '127.0.0.0/8') + self.assertEqual(instance.rule, 'dst4=127.0.0.0/8') + + def test_004_ipv4_invalid_prefix(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('127.0.0.0/33') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('127.0.0.0', 33) + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('127.0.0.0/-1') + + def test_005_ipv4_reject_shortened(self): + # not strictly required, but ppl are used to it + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('127/8') + + def test_006_ipv4_invalid_addr(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('137.327.0.0/16') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('1.2.3.4.5/32') + + @unittest.expectedFailure + def test_007_ipv4_invalid_network(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('127.0.0.1/32') + + def test_010_ipv6(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstHost('2001:abcd:efab::3') + self.assertEqual(instance.type, 'dst6') + self.assertEqual(instance.prefixlen, 128) + self.assertEqual(str(instance), '2001:abcd:efab::3/128') + self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::3/128') + + def test_011_ipv6_prefixlen(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstHost('2001:abcd:efab::', 64) + self.assertEqual(instance.type, 'dst6') + self.assertEqual(instance.prefixlen, 64) + self.assertEqual(str(instance), '2001:abcd:efab::/64') + self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') + + def test_012_ipv6_parse_prefixlen(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstHost('2001:abcd:efab::/64') + self.assertEqual(instance.type, 'dst6') + self.assertEqual(instance.prefixlen, 64) + self.assertEqual(str(instance), '2001:abcd:efab::/64') + self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') + + def test_013_ipv6_invalid_prefix(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efab::3/129') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efab::3', 129) + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efab::3/-1') + + def test_014_ipv6_invalid_addr(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efab0123::3/128') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efab:3/128') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efab:a:a:a:a:a:a:3/128') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efgh::3/128') + + @unittest.expectedFailure + def test_015_ipv6_invalid_network(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('2001:abcd:efab::3/64') + + @unittest.expectedFailure + def test_020_invalid_hostname(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('www qubes-os.org') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstHost('https://qubes-os.org') + + +class TC_03_DstPorts(qubesmgmt.tests.QubesTestCase): + def test_000_single_str(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstPorts('80') + self.assertEqual(str(instance), '80') + self.assertEqual(instance.range, [80, 80]) + self.assertEqual(instance.rule, 'dstports=80-80') + + def test_001_single_int(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstPorts(80) + self.assertEqual(str(instance), '80') + self.assertEqual(instance.range, [80, 80]) + self.assertEqual(instance.rule, 'dstports=80-80') + + def test_002_range(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.DstPorts('80-90') + self.assertEqual(str(instance), '80-90') + self.assertEqual(instance.range, [80, 90]) + self.assertEqual(instance.rule, 'dstports=80-90') + + def test_003_invalid(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts('80-90-100') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts('abcdef') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts('80 90') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts('') + + def test_004_reversed_range(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts('100-20') + + def test_005_out_of_range(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts('1000000000000') + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts(1000000000000) + with self.assertRaises(ValueError): + qubesmgmt.firewall.DstPorts('1-1000000000000') + + +class TC_04_IcmpType(qubesmgmt.tests.QubesTestCase): + def test_000_number(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.IcmpType(8) + self.assertEqual(str(instance), '8') + self.assertEqual(instance.rule, 'icmptype=8') + + def test_001_str(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.IcmpType('8') + self.assertEqual(str(instance), '8') + self.assertEqual(instance.rule, 'icmptype=8') + + def test_002_invalid(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.IcmpType(600) + with self.assertRaises(ValueError): + qubesmgmt.firewall.IcmpType(-1) + with self.assertRaises(ValueError): + qubesmgmt.firewall.IcmpType('abcde') + with self.assertRaises(ValueError): + qubesmgmt.firewall.IcmpType('') + + +class TC_05_SpecialTarget(qubesmgmt.tests.QubesTestCase): + def test_000_allowed_values(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.SpecialTarget('dns') + self.assertEqual( + set(instance.allowed_values), {'dns'}) + + def test_001_rule(self): + instance = qubesmgmt.firewall.SpecialTarget('dns') + self.assertEqual(instance.rule, 'specialtarget=dns') + + +class TC_06_Expire(qubesmgmt.tests.QubesTestCase): + def test_000_number(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.Expire(1463292452) + self.assertEqual(str(instance), '1463292452') + self.assertEqual(instance.datetime, + datetime.datetime(2016, 5, 15, 6, 7, 32)) + self.assertEqual(instance.rule, 'expire=1463292452') + + def test_001_str(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.Expire('1463292452') + self.assertEqual(str(instance), '1463292452') + self.assertEqual(instance.datetime, + datetime.datetime(2016, 5, 15, 6, 7, 32)) + self.assertEqual(instance.rule, 'expire=1463292452') + + def test_002_invalid(self): + with self.assertRaises(ValueError): + qubesmgmt.firewall.Expire('abcdef') + with self.assertRaises(ValueError): + qubesmgmt.firewall.Expire('') + + def test_003_expired(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.Expire('1463292452') + self.assertTrue(instance.expired) + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.Expire('1583292452') + self.assertFalse(instance.expired) + + +class TC_07_Comment(qubesmgmt.tests.QubesTestCase): + def test_000_str(self): + with self.assertNotRaises(ValueError): + instance = qubesmgmt.firewall.Comment('Some comment') + self.assertEqual(str(instance), 'Some comment') + self.assertEqual(instance.rule, 'comment=Some comment') + + +class TC_10_Rule(qubesmgmt.tests.QubesTestCase): + def test_000_simple(self): + with self.assertNotRaises(ValueError): + rule = qubesmgmt.firewall.Rule(None, action='accept', proto='icmp') + self.assertEqual(rule.rule, 'action=accept proto=icmp') + self.assertIsNone(rule.dsthost) + self.assertIsNone(rule.dstports) + self.assertIsNone(rule.icmptype) + self.assertIsNone(rule.comment) + self.assertIsNone(rule.expire) + self.assertEqual(str(rule.action), 'accept') + self.assertEqual(str(rule.proto), 'icmp') + + def test_001_expire(self): + with self.assertNotRaises(ValueError): + rule = qubesmgmt.firewall.Rule(None, action='accept', proto='icmp', + expire='1463292452') + self.assertEqual(rule.rule, + 'action=accept proto=icmp expire=1463292452') + + + def test_002_dstports(self): + with self.assertNotRaises(ValueError): + rule = qubesmgmt.firewall.Rule(None, action='accept', proto='tcp', + dstports=80) + self.assertEqual(str(rule.dstports), '80') + + def test_003_reject_invalid(self): + with self.assertRaises((ValueError, AssertionError)): + # missing action + qubesmgmt.firewall.Rule(None, proto='icmp') + with self.assertRaises(ValueError): + # not proto=tcp or proto=udp for dstports + qubesmgmt.firewall.Rule(None, action='accept', proto='icmp', + dstports=80) + with self.assertRaises(ValueError): + # not proto=tcp or proto=udp for dstports + qubesmgmt.firewall.Rule(None, action='accept', dstports=80) + with self.assertRaises(ValueError): + # not proto=icmp for icmptype + qubesmgmt.firewall.Rule(None, action='accept', proto='tcp', + icmptype=8) + with self.assertRaises(ValueError): + # not proto=icmp for icmptype + qubesmgmt.firewall.Rule(None, action='accept', icmptype=8) + + def test_004_proto_change(self): + rule = qubesmgmt.firewall.Rule(None, action='accept', proto='tcp') + with self.assertNotRaises(ValueError): + rule.proto = 'udp' + self.assertEqual(rule.rule, 'action=accept proto=udp') + rule = qubesmgmt.firewall.Rule(None, action='accept', proto='tcp', + dstports=80) + with self.assertNotRaises(ValueError): + rule.proto = 'udp' + self.assertEqual(rule.rule, 'action=accept proto=udp dstports=80-80') + rule = qubesmgmt.firewall.Rule(None, action='accept') + with self.assertNotRaises(ValueError): + rule.proto = 'udp' + self.assertEqual(rule.rule, 'action=accept proto=udp') + with self.assertNotRaises(ValueError): + rule.dstports = 80 + self.assertEqual(rule.rule, 'action=accept proto=udp dstports=80-80') + with self.assertNotRaises(ValueError): + rule.proto = 'icmp' + self.assertEqual(rule.rule, 'action=accept proto=icmp') + self.assertIsNone(rule.dstports) + rule.icmptype = 8 + self.assertEqual(rule.rule, 'action=accept proto=icmp icmptype=8') + with self.assertNotRaises(ValueError): + rule.proto = None + self.assertEqual(rule.rule, 'action=accept') + self.assertIsNone(rule.dstports) + + def test_005_parse_str(self): + rule_txt = \ + 'action=accept dst4=192.168.0.0/24 proto=tcp dstports=443' + with self.assertNotRaises(ValueError): + rule = qubesmgmt.firewall.Rule(rule_txt) + self.assertEqual(rule.dsthost, '192.168.0.0/24') + self.assertEqual(rule.proto, 'tcp') + self.assertEqual(rule.dstports, '443') + self.assertIsNone(rule.expire) + self.assertIsNone(rule.comment) + + def test_006_parse_str_comment(self): + rule_txt = \ + 'action=accept dsthost=qubes-os.org comment=Some comment' + with self.assertNotRaises(ValueError): + rule = qubesmgmt.firewall.Rule(rule_txt) + self.assertEqual(rule.dsthost, 'qubes-os.org') + self.assertIsNone(rule.proto) + self.assertIsNone(rule.dstports) + self.assertIsNone(rule.expire) + self.assertEqual(rule.comment, 'Some comment') + + +class TC_11_Firewall(qubesmgmt.tests.QubesTestCase): + def setUp(self): + super(TC_11_Firewall, self).setUp() + self.app.expected_calls[('dom0', 'mgmt.vm.List', None, None)] = \ + b'0\0test-vm class=AppVM state=Halted\n' + self.vm = self.app.domains['test-vm'] + + def test_000_policy_get(self): + self.app.expected_calls[('test-vm', 'mgmt.vm.firewall.GetPolicy', + None, None)] = b'0\0accept' + policy = self.vm.firewall.policy + self.assertEqual(policy, 'accept') + self.assertEqual(policy, qubesmgmt.firewall.Action('accept')) + self.assertAllCalled() + + def test_001_policy_set(self): + self.app.expected_calls[('test-vm', 'mgmt.vm.firewall.SetPolicy', + None, b'drop')] = b'0\0' + self.vm.firewall.policy = 'drop' + self.assertAllCalled() + + def test_002_policy_set2(self): + self.app.expected_calls[('test-vm', 'mgmt.vm.firewall.SetPolicy', + None, b'drop')] = b'0\0' + self.vm.firewall.policy = qubesmgmt.firewall.Action('drop') + self.assertAllCalled() + + def test_010_load_rules(self): + self.app.expected_calls[('test-vm', 'mgmt.vm.firewall.Get', + None, None)] = \ + b'0\0action=accept dsthost=qubes-os.org\n' \ + b'action=drop proto=icmp\n' + rules = self.vm.firewall.rules + self.assertListEqual(rules, [ + qubesmgmt.firewall.Rule('action=accept dsthost=qubes-os.org'), + qubesmgmt.firewall.Rule('action=drop proto=icmp'), + ]) + # check caching + del self.app.expected_calls[('test-vm', 'mgmt.vm.firewall.Get', + None, None)] + rules2 = self.vm.firewall.rules + self.assertEqual(rules, rules2) + # then force reload + self.app.expected_calls[('test-vm', 'mgmt.vm.firewall.Get', + None, None)] = \ + b'0\0action=accept dsthost=qubes-os.org proto=tcp dstports=443\n' + self.vm.firewall.load_rules() + rules3 = self.vm.firewall.rules + self.assertListEqual(rules3, [ + qubesmgmt.firewall.Rule( + 'action=accept dsthost=qubes-os.org proto=tcp dstports=443')]) + self.assertAllCalled() + + def test_020_set_rules(self): + rules_txt = ( + 'action=accept proto=tcp dsthost=qubes-os.org dstports=443-443', + 'action=accept dsthost=example.com', + ) + rules = [qubesmgmt.firewall.Rule(rule) for rule in rules_txt] + self.app.expected_calls[('test-vm', 'mgmt.vm.firewall.Set', None, + ''.join(rule + '\n' for rule in rules_txt).encode('ascii'))] = b'0\0' + self.vm.firewall.rules = rules + self.assertAllCalled() \ No newline at end of file diff --git a/qubesmgmt/vm/__init__.py b/qubesmgmt/vm/__init__.py index 4f4ee63..8029d72 100644 --- a/qubesmgmt/vm/__init__.py +++ b/qubesmgmt/vm/__init__.py @@ -26,6 +26,7 @@ import qubesmgmt.exc import qubesmgmt.storage import qubesmgmt.features import qubesmgmt.devices +import qubesmgmt.firewall class QubesVM(qubesmgmt.base.PropertyHolder): @@ -37,12 +38,15 @@ class QubesVM(qubesmgmt.base.PropertyHolder): devices = None + firewall = None + def __init__(self, app, name): super(QubesVM, self).__init__(app, 'mgmt.vm.property.', name) self._volumes = None self.log = logging.getLogger(name) self.features = qubesmgmt.features.Features(self) self.devices = qubesmgmt.devices.DeviceManager(self) + self.firewall = qubesmgmt.firewall.Firewall(self) @property def name(self):