diff --git a/qubes/__init__.py b/qubes/__init__.py index 7f933770..15934b2b 100644 --- a/qubes/__init__.py +++ b/qubes/__init__.py @@ -331,11 +331,12 @@ class property(object): # pylint: disable=redefined-builtin,invalid-name # do not treat type='str' as sufficient validation if self.type is not None and self.type is not str: # assume specific type will preform enough validation + try: + untrusted_newvalue = untrusted_newvalue.decode('ascii', + errors='strict') + except UnicodeDecodeError: + raise qubes.exc.QubesValueError if self.type is bool: - try: - untrusted_newvalue = untrusted_newvalue.decode('ascii') - except UnicodeDecodeError: - raise qubes.exc.QubesValueError return self.bool(None, None, untrusted_newvalue) else: try: diff --git a/qubes/api/admin.py b/qubes/api/admin.py index 6fab0166..5a72c1e2 100644 --- a/qubes/api/admin.py +++ b/qubes/api/admin.py @@ -30,6 +30,7 @@ import libvirt import qubes.api import qubes.devices +import qubes.firewall import qubes.storage import qubes.utils import qubes.vm @@ -992,3 +993,38 @@ class QubesAdminAPI(qubes.api.AbstractQubesAPI): dev.backend_domain, dev.ident) self.dest.devices[devclass].detach(assignment) self.app.save() + + @qubes.api.method('admin.vm.firewall.Get', no_payload=True) + @asyncio.coroutine + def vm_firewall_get(self): + assert not self.arg + + self.fire_event_for_permission() + + return ''.join('{}\n'.format(rule.api_rule) + for rule in self.dest.firewall.rules) + + @qubes.api.method('admin.vm.firewall.Set') + @asyncio.coroutine + def vm_firewall_set(self, untrusted_payload): + assert not self.arg + rules = [] + for untrusted_line in untrusted_payload.decode('ascii', + errors='strict').splitlines(): + rule = qubes.firewall.Rule.from_api_string( + untrusted_rule=untrusted_line) + rules.append(rule) + + self.fire_event_for_permission(rules=rules) + + self.dest.firewall.rules = rules + self.dest.firewall.save() + + @qubes.api.method('admin.vm.firewall.Reload', no_payload=True) + @asyncio.coroutine + def vm_firewall_reload(self): + assert not self.arg + + self.fire_event_for_permission() + + self.dest.fire_event('firewall-changed') diff --git a/qubes/firewall.py b/qubes/firewall.py index 502fe7a1..a9d195e8 100644 --- a/qubes/firewall.py +++ b/qubes/firewall.py @@ -22,6 +22,7 @@ # import datetime +import string import subprocess import itertools @@ -34,13 +35,22 @@ import qubes.vm.qubesvm class RuleOption(object): - def __init__(self, value): - self._value = str(value) + def __init__(self, untrusted_value): + # subset of string.punctuation + safe_set = string.ascii_letters + string.digits + \ + ':;,./-_[]' + assert all(x in safe_set for x in str(untrusted_value)) + value = str(untrusted_value) + self._value = value @property def rule(self): raise NotImplementedError + @property + def api_rule(self): + return self.rule + def __str__(self): return self._value @@ -50,14 +60,15 @@ class RuleOption(object): # noinspection PyAbstractClass class RuleChoice(RuleOption): # pylint: disable=abstract-method - def __init__(self, value): - super(RuleChoice, self).__init__(value) + def __init__(self, untrusted_value): + # preliminary validation + super(RuleChoice, self).__init__(untrusted_value) self.allowed_values = \ [v for k, v in self.__class__.__dict__.items() if not k.startswith('__') and isinstance(v, str) and not v.startswith('__')] - if value not in self.allowed_values: - raise ValueError(value) + if untrusted_value not in self.allowed_values: + raise ValueError(untrusted_value) class Action(RuleChoice): @@ -81,14 +92,14 @@ class Proto(RuleChoice): class DstHost(RuleOption): '''Represent host/network address: either IPv4, IPv6, or DNS name''' - def __init__(self, value, prefixlen=None): - # TODO: in python >= 3.3 ipaddress module could be used - if value.count('/') > 1: - raise ValueError('Too many /: ' + value) - elif not value.count('/'): + def __init__(self, untrusted_value, prefixlen=None): + if untrusted_value.count('/') > 1: + raise ValueError('Too many /: ' + untrusted_value) + elif not untrusted_value.count('/'): # add prefix length to bare IP addresses try: - socket.inet_pton(socket.AF_INET6, value) + socket.inet_pton(socket.AF_INET6, untrusted_value) + value = untrusted_value self.prefixlen = prefixlen or 128 if self.prefixlen < 0 or self.prefixlen > 128: raise ValueError( @@ -97,10 +108,11 @@ class DstHost(RuleOption): self.type = 'dst6' except socket.error: try: - socket.inet_pton(socket.AF_INET, value) - if value.count('.') != 3: + socket.inet_pton(socket.AF_INET, untrusted_value) + if untrusted_value.count('.') != 3: raise ValueError( 'Invalid number of dots in IPv4 address') + value = untrusted_value self.prefixlen = prefixlen or 32 if self.prefixlen < 0 or self.prefixlen > 32: raise ValueError( @@ -110,28 +122,33 @@ class DstHost(RuleOption): except socket.error: self.type = 'dsthost' self.prefixlen = 0 + safe_set = string.ascii_lowercase + string.digits + '-._' + assert all(c in safe_set for c in untrusted_value) + value = untrusted_value else: - host, prefixlen = value.split('/', 1) - prefixlen = int(prefixlen) + untrusted_host, untrusted_prefixlen = untrusted_value.split('/', 1) + prefixlen = int(untrusted_prefixlen) if prefixlen < 0: raise ValueError('netmask must be non-negative') self.prefixlen = prefixlen try: - socket.inet_pton(socket.AF_INET6, host) + socket.inet_pton(socket.AF_INET6, untrusted_host) + value = untrusted_value if prefixlen > 128: raise ValueError('netmask for IPv6 must be <= 128') self.type = 'dst6' except socket.error: try: - socket.inet_pton(socket.AF_INET, host) + socket.inet_pton(socket.AF_INET, untrusted_host) if prefixlen > 32: raise ValueError('netmask for IPv4 must be <= 32') self.type = 'dst4' - if host.count('.') != 3: + if untrusted_host.count('.') != 3: raise ValueError( 'Invalid number of dots in IPv4 address') + value = untrusted_value except socket.error: - raise ValueError('Invalid IP address: ' + host) + raise ValueError('Invalid IP address: ' + untrusted_host) super(DstHost, self).__init__(value) @@ -141,15 +158,15 @@ class DstHost(RuleOption): class DstPorts(RuleOption): - def __init__(self, value): - if isinstance(value, int): - value = str(value) - if value.count('-') == 1: - self.range = [int(x) for x in value.split('-', 1)] - elif not value.count('-'): - self.range = [int(value), int(value)] + def __init__(self, untrusted_value): + if isinstance(untrusted_value, int): + untrusted_value = str(untrusted_value) + if untrusted_value.count('-') == 1: + self.range = [int(x) for x in untrusted_value.split('-', 1)] + elif not untrusted_value.count('-'): + self.range = [int(untrusted_value), int(untrusted_value)] else: - raise ValueError(value) + raise ValueError(untrusted_value) if any(port < 0 or port > 65536 for port in self.range): raise ValueError('Ports out of range') if self.range[0] > self.range[1]: @@ -164,11 +181,11 @@ class DstPorts(RuleOption): class IcmpType(RuleOption): - def __init__(self, value): - super(IcmpType, self).__init__(value) - value = int(value) - if value < 0 or value > 255: + def __init__(self, untrusted_value): + untrusted_value = int(untrusted_value) + if untrusted_value < 0 or untrusted_value > 255: raise ValueError('ICMP type out of range') + super(IcmpType, self).__init__(untrusted_value) @property def rule(self): @@ -184,24 +201,42 @@ class SpecialTarget(RuleChoice): class Expire(RuleOption): - def __init__(self, value): - super(Expire, self).__init__(value) - self.datetime = datetime.datetime.utcfromtimestamp(int(value)) + def __init__(self, untrusted_value): + super(Expire, self).__init__(untrusted_value) + self.datetime = datetime.datetime.utcfromtimestamp(int(untrusted_value)) @property def rule(self): return None + @property + def api_rule(self): + return 'expire=' + str(self) + @property def expired(self): return self.datetime < datetime.datetime.utcnow() class Comment(RuleOption): + # noinspection PyMissingConstructor + def __init__(self, untrusted_value): + # pylint: disable=super-init-not-called + # subset of string.punctuation + safe_set = string.ascii_letters + string.digits + \ + ':;,./-_[] ' + assert all(x in safe_set for x in str(untrusted_value)) + value = str(untrusted_value) + self._value = value + @property def rule(self): return None + @property + def api_rule(self): + return 'comment=' + str(self) + class Rule(qubes.PropertyHolder): def __init__(self, xml=None, **kwargs): @@ -311,6 +346,20 @@ class Rule(qubes.PropertyHolder): values.append(value.rule) return ' '.join(values) + @property + def api_rule(self): + values = [] + # put comment at the end + for prop in sorted(self.property_list(), + key=(lambda p: p.__name__ == 'comment')): + value = getattr(self, prop.__name__) + if value is None: + continue + if value.api_rule is None: + continue + values.append(value.api_rule) + return ' '.join(values) + @classmethod def from_xml_v1(cls, node, action): netmask = node.get('netmask') @@ -358,8 +407,43 @@ class Rule(qubes.PropertyHolder): return cls(**kwargs) + @classmethod + def from_api_string(cls, untrusted_rule): + '''Parse a single line of firewall rule''' + # comment is allowed to have spaces + untrusted_options, _, untrusted_comment = untrusted_rule.partition( + 'comment=') + # appropriate handlers in __init__ of individual options will perform + # option-specific validation + kwargs = {} + if untrusted_comment: + kwargs['comment'] = Comment(untrusted_value=untrusted_comment) + + for untrusted_option in untrusted_options.strip().split(' '): + untrusted_key, untrusted_value = untrusted_option.split('=', 1) + if untrusted_key in kwargs: + raise ValueError('Option \'{}\' already set'.format( + untrusted_key)) + if untrusted_key in [str(prop) for prop in cls.property_list()]: + kwargs[untrusted_key] = cls.property_get_def( + untrusted_key).type(untrusted_value=untrusted_value) + elif untrusted_key in ('dst4', 'dst6', 'dstname'): + if 'dsthost' in kwargs: + raise ValueError('Option \'{}\' already set'.format( + 'dsthost')) + kwargs['dsthost'] = DstHost(untrusted_value=untrusted_value) + else: + raise ValueError('Unknown firewall option') + + return cls(**kwargs) + def __eq__(self, other): - return self.rule == other.rule + if isinstance(other, Rule): + return self.api_rule == other.api_rule + return self.api_rule == str(other) + + def __hash__(self): + return hash(self.api_rule) class Firewall(object): @@ -368,21 +452,23 @@ class Firewall(object): self.vm = vm #: firewall rules self.rules = [] - #: default action - self.policy = None if load: self.load() + @property + def policy(self): + ''' Default action - always 'drop' ''' + return Action('drop') + def __eq__(self, other): if isinstance(other, Firewall): - return self.policy == other.policy and self.rules == other.rules + return self.rules == other.rules return NotImplemented def load_defaults(self): '''Load default firewall settings''' - self.rules = [] - self.policy = Action('accept') + self.rules = [Rule(None, action='accept')] def clone(self, other): '''Clone firewall settings from other instance. @@ -390,7 +476,6 @@ class Firewall(object): :param other: other :py:class:`Firewall` instance ''' - self.policy = other.policy rules = [] for rule in other.rules: new_rule = Rule() @@ -421,10 +506,7 @@ class Firewall(object): '''Load old (Qubes < 4.0) firewall XML format''' policy_v1 = xml_root.get('policy') assert policy_v1 in ('allow', 'deny') - if policy_v1 == 'allow': - self.policy = Action('accept') - else: - self.policy = Action('drop') + default_policy_is_accept = (policy_v1 == 'allow') def _translate_action(key): if xml_root.get(key, policy_v1) == 'allow': @@ -439,7 +521,7 @@ class Firewall(object): action=_translate_action('icmp'), proto=Proto.icmp)) - if self.policy == Action.accept: + if default_policy_is_accept: rule_action = Action.drop else: rule_action = Action.accept @@ -447,11 +529,11 @@ class Firewall(object): for element in xml_root: rule = Rule.from_xml_v1(element, rule_action) self.rules.append(rule) + if default_policy_is_accept: + self.rules.append(Rule(None, action='accept')) def load_v2(self, xml_root): '''Load new (Qubes >= 4.0) firewall XML format''' - self.policy = Action(xml_root.findtext('policy')) - xml_rules = xml_root.find('rules') for xml_rule in xml_rules: rule = Rule(xml_rule) @@ -464,10 +546,6 @@ class Firewall(object): xml_root = lxml.etree.Element('firewall', version=str(2)) - xml_policy = lxml.etree.Element('policy') - xml_policy.text = str(self.policy) - xml_root.append(xml_policy) - xml_rules = lxml.etree.Element('rules') for rule in self.rules: if rule.expire: @@ -499,7 +577,6 @@ class Firewall(object): subprocess.call(["sudo", "systemctl", "start", "qubes-reload-firewall@%s.timer" % self.vm.name]) - def qdb_entries(self, addr_family=None): '''Return firewall settings serialized for QubesDB entries diff --git a/qubes/tests/__init__.py b/qubes/tests/__init__.py index b4c2ec51..3d19249e 100644 --- a/qubes/tests/__init__.py +++ b/qubes/tests/__init__.py @@ -151,7 +151,9 @@ class TestEmitter(qubes.events.Emitter): effects = super(TestEmitter, self).fire_event(event, **kwargs) ev_kwargs = frozenset( (key, - frozenset(value.items()) if isinstance(value, dict) else value) + frozenset(value.items()) if isinstance(value, dict) + else tuple(value) if isinstance(value, list) + else value) for key, value in kwargs.items() ) self.fired_events[(event, ev_kwargs)] += 1 @@ -161,7 +163,9 @@ class TestEmitter(qubes.events.Emitter): effects = super(TestEmitter, self).fire_event_pre(event, **kwargs) ev_kwargs = frozenset( (key, - frozenset(value.items()) if isinstance(value, dict) else value) + frozenset(value.items()) if isinstance(value, dict) + else tuple(value) if isinstance(value, list) + else value) for key, value in kwargs.items() ) self.fired_events[(event, ev_kwargs)] += 1 diff --git a/qubes/tests/api_admin.py b/qubes/tests/api_admin.py index 8a33eae9..8b6eadcf 100644 --- a/qubes/tests/api_admin.py +++ b/qubes/tests/api_admin.py @@ -30,6 +30,7 @@ import libvirt import qubes import qubes.devices +import qubes.firewall import qubes.api.admin import qubes.tests import qubes.storage @@ -1677,6 +1678,131 @@ class TC_00_VMs(AdminAPITestCase): self.assertNotIn('+.some-tag', self.vm.tags) self.assertFalse(self.app.save.called) + def test_570_firewall_get(self): + self.vm.firewall.save = unittest.mock.Mock() + value = self.call_mgmt_func(b'admin.vm.firewall.Get', + b'test-vm1', b'') + self.assertEqual(value, 'action=accept\n') + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_571_firewall_get_non_default(self): + self.vm.firewall.save = unittest.mock.Mock() + self.vm.firewall.rules = [ + qubes.firewall.Rule(action='accept', proto='tcp', + dstports='1-1024'), + qubes.firewall.Rule(action='drop', proto='icmp', + comment='No ICMP'), + qubes.firewall.Rule(action='drop', proto='udp', + expire='1499450306'), + qubes.firewall.Rule(action='accept'), + ] + value = self.call_mgmt_func(b'admin.vm.firewall.Get', + b'test-vm1', b'') + self.assertEqual(value, + 'action=accept proto=tcp dstports=1-1024\n' + 'action=drop proto=icmp comment=No ICMP\n' + 'action=drop expire=1499450306 proto=udp\n' + 'action=accept\n') + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_580_firewall_set_simple(self): + self.vm.firewall.save = unittest.mock.Mock() + value = self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', b'action=accept\n') + self.assertEqual(self.vm.firewall.rules, + ['action=accept']) + self.assertTrue(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_581_firewall_set_multi(self): + self.vm.firewall.save = unittest.mock.Mock() + rules = [ + qubes.firewall.Rule(action='accept', proto='tcp', + dstports='1-1024'), + qubes.firewall.Rule(action='drop', proto='icmp', + comment='No ICMP'), + qubes.firewall.Rule(action='drop', proto='udp', + expire='1499450306'), + qubes.firewall.Rule(action='accept'), + ] + rules_txt = ( + 'action=accept proto=tcp dstports=1-1024\n' + 'action=drop proto=icmp comment=No ICMP\n' + 'action=drop expire=1499450306 proto=udp\n' + 'action=accept\n') + value = self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, rules) + self.assertTrue(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_582_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'action=accept protoxyz=tcp dst4=127.0.0.1\n' + 'action=drop\n') + with self.assertRaises(ValueError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_583_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'proto=tcp dstports=1-1024\n' + 'action=drop\n') + with self.assertRaises(AssertionError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_584_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'action=accept proto=tcp dstports=1-1024 ' + 'action=drop\n') + with self.assertRaises(ValueError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_585_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'action=accept dstports=1-1024 comment=ążźł\n' + 'action=drop\n') + with self.assertRaises(UnicodeDecodeError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_590_firewall_reload(self): + self.vm.firewall.save = unittest.mock.Mock() + self.app.domains['test-vm1'].fire_event = self.emitter.fire_event + self.app.domains['test-vm1'].fire_event_pre = \ + self.emitter.fire_event_pre + value = self.call_mgmt_func(b'admin.vm.firewall.Reload', + b'test-vm1', b'') + self.assertIsNone(value) + self.assertEventFired(self.emitter, 'firewall-changed') + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_990_vm_unexpected_payload(self): methods_with_no_payload = [ b'admin.vm.List', @@ -1695,8 +1821,7 @@ class TC_00_VMs(AdminAPITestCase): b'admin.vm.tag.Remove', b'admin.vm.tag.Set', b'admin.vm.firewall.Get', - b'admin.vm.firewall.RemoveRule', - b'admin.vm.firewall.Flush', + b'admin.vm.firewall.Reload', b'admin.vm.device.pci.Attach', b'admin.vm.device.pci.Detach', b'admin.vm.device.pci.List', @@ -1748,8 +1873,9 @@ class TC_00_VMs(AdminAPITestCase): b'admin.vm.property.List', b'admin.vm.feature.List', b'admin.vm.tag.List', - b'admin.vm.firewall.List', - b'admin.vm.firewall.Flush', + b'admin.vm.firewall.Get', + b'admin.vm.firewall.Set', + b'admin.vm.firewall.Reload', b'admin.vm.microphone.Attach', b'admin.vm.microphone.Detach', b'admin.vm.microphone.Status', @@ -1950,9 +2076,8 @@ class TC_00_VMs(AdminAPITestCase): b'admin.vm.tag.Remove', b'admin.vm.tag.Set', b'admin.vm.firewall.Get', - b'admin.vm.firewall.RemoveRule', - b'admin.vm.firewall.InsertRule', - b'admin.vm.firewall.Flush', + b'admin.vm.firewall.Set', + b'admin.vm.firewall.Reload', b'admin.vm.device.pci.Attach', b'admin.vm.device.pci.Detach', b'admin.vm.device.pci.List', diff --git a/qubes/tests/firewall.py b/qubes/tests/firewall.py index e17a616d..f4594d45 100644 --- a/qubes/tests/firewall.py +++ b/qubes/tests/firewall.py @@ -80,6 +80,7 @@ class TC_01_Action(qubes.tests.QubesTestCase): def test_001_rule(self): instance = qubes.firewall.Action('accept') self.assertEqual(instance.rule, 'action=accept') + self.assertEqual(instance.api_rule, 'action=accept') # noinspection PyPep8Naming @@ -93,6 +94,7 @@ class TC_02_Proto(qubes.tests.QubesTestCase): def test_001_rule(self): instance = qubes.firewall.Proto('tcp') self.assertEqual(instance.rule, 'proto=tcp') + self.assertEqual(instance.api_rule, 'proto=tcp') # noinspection PyPep8Naming @@ -157,6 +159,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase): self.assertEqual(instance.prefixlen, 128) self.assertEqual(str(instance), '2001:abcd:efab::3/128') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::3/128') + self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::3/128') def test_011_ipv6_prefixlen(self): with self.assertNotRaises(ValueError): @@ -165,6 +168,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase): self.assertEqual(instance.prefixlen, 64) self.assertEqual(str(instance), '2001:abcd:efab::/64') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') + self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64') def test_012_ipv6_parse_prefixlen(self): with self.assertNotRaises(ValueError): @@ -173,6 +177,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase): self.assertEqual(instance.prefixlen, 64) self.assertEqual(str(instance), '2001:abcd:efab::/64') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') + self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64') def test_013_ipv6_invalid_prefix(self): with self.assertRaises(ValueError): @@ -211,6 +216,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase): self.assertEqual(str(instance), '80') self.assertEqual(instance.range, [80, 80]) self.assertEqual(instance.rule, 'dstports=80-80') + self.assertEqual(instance.api_rule, 'dstports=80-80') def test_001_single_int(self): with self.assertNotRaises(ValueError): @@ -218,6 +224,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase): self.assertEqual(str(instance), '80') self.assertEqual(instance.range, [80, 80]) self.assertEqual(instance.rule, 'dstports=80-80') + self.assertEqual(instance.api_rule, 'dstports=80-80') def test_002_range(self): with self.assertNotRaises(ValueError): @@ -261,6 +268,7 @@ class TC_04_IcmpType(qubes.tests.QubesTestCase): instance = qubes.firewall.IcmpType('8') self.assertEqual(str(instance), '8') self.assertEqual(instance.rule, 'icmptype=8') + self.assertEqual(instance.api_rule, 'icmptype=8') def test_002_invalid(self): with self.assertRaises(ValueError): @@ -283,6 +291,7 @@ class TC_05_SpecialTarget(qubes.tests.QubesTestCase): def test_001_rule(self): instance = qubes.firewall.SpecialTarget('dns') self.assertEqual(instance.rule, 'specialtarget=dns') + self.assertEqual(instance.api_rule, 'specialtarget=dns') class TC_06_Expire(qubes.tests.QubesTestCase): @@ -290,6 +299,7 @@ class TC_06_Expire(qubes.tests.QubesTestCase): with self.assertNotRaises(ValueError): instance = qubes.firewall.Expire(1463292452) self.assertEqual(str(instance), '1463292452') + self.assertEqual(instance.api_rule, 'expire=1463292452') self.assertEqual(instance.datetime, datetime.datetime(2016, 5, 15, 6, 7, 32)) self.assertIsNone(instance.rule) @@ -322,6 +332,7 @@ class TC_07_Comment(qubes.tests.QubesTestCase): with self.assertNotRaises(ValueError): instance = qubes.firewall.Comment('Some comment') self.assertEqual(str(instance), 'Some comment') + self.assertEqual(instance.api_rule, 'comment=Some comment') self.assertIsNone(instance.rule) @@ -445,6 +456,33 @@ class TC_08_Rule(qubes.tests.QubesTestCase): self.assertIsNone(rule.proto) self.assertIsNone(rule.dstports) + def test_008_from_api_string(self): + rule_txt = 'action=drop proto=tcp dstports=80-80' + with self.assertNotRaises(ValueError): + rule = qubes.firewall.Rule.from_api_string( + rule_txt) + self.assertEqual(rule.dstports.range, [80, 80]) + self.assertEqual(rule.proto, 'tcp') + self.assertEqual(rule.action, 'drop') + self.assertIsNone(rule.dsthost) + self.assertIsNone(rule.expire) + self.assertIsNone(rule.comment) + self.assertEqual(rule.api_rule, rule_txt) + + def test_009_from_api_string(self): + rule_txt = 'action=accept expire=1463292452 proto=tcp ' \ + 'comment=Some comment, with spaces' + with self.assertNotRaises(ValueError): + rule = qubes.firewall.Rule.from_api_string( + rule_txt) + self.assertEqual(rule.comment, 'Some comment, with spaces') + self.assertEqual(rule.proto, 'tcp') + self.assertEqual(rule.action, 'accept') + self.assertEqual(rule.expire, '1463292452') + self.assertIsNone(rule.dstports) + self.assertIsNone(rule.dsthost) + self.assertEqual(rule.api_rule, rule_txt) + class TC_10_Firewall(qubes.tests.QubesTestCase): def setUp(self): @@ -463,17 +501,17 @@ class TC_10_Firewall(qubes.tests.QubesTestCase): def test_000_defaults(self): fw = qubes.firewall.Firewall(self.vm, False) fw.load_defaults() - self.assertEqual(fw.policy, 'accept') - self.assertEqual(fw.rules, []) + self.assertEqual(fw.policy, 'drop') + self.assertEqual(fw.rules, [qubes.firewall.Rule(None, action='accept')]) def test_001_save_load_empty(self): fw = qubes.firewall.Firewall(self.vm, True) - self.assertEqual(fw.policy, 'accept') - self.assertEqual(fw.rules, []) + self.assertEqual(fw.policy, 'drop') + self.assertEqual(fw.rules, [qubes.firewall.Rule(None, action='accept')]) fw.save() fw.load() - self.assertEqual(fw.policy, 'accept') - self.assertEqual(fw.rules, []) + self.assertEqual(fw.policy, 'drop') + self.assertEqual(fw.rules, [qubes.firewall.Rule(None, action='accept')]) def test_002_save_load_rules(self): fw = qubes.firewall.Firewall(self.vm, True) @@ -485,13 +523,13 @@ class TC_10_Firewall(qubes.tests.QubesTestCase): qubes.firewall.Rule(None, action='accept', specialtarget='dns'), ] fw.rules.extend(rules) - fw.policy = qubes.firewall.Action.drop fw.save() self.assertTrue(os.path.exists(os.path.join( self.vm.dir_path, self.vm.firewall_conf))) fw = qubes.firewall.Firewall(TestVM(), True) self.assertEqual(fw.policy, qubes.firewall.Action.drop) - self.assertEqual(fw.rules, rules) + self.assertEqual(fw.rules, + [qubes.firewall.Rule(None, action='accept')] + rules) def test_003_load_v1(self): xml_txt = """