diff --git a/qubes/api/admin.py b/qubes/api/admin.py index ce198f19..10445888 100644 --- a/qubes/api/admin.py +++ b/qubes/api/admin.py @@ -30,6 +30,7 @@ import libvirt import qubes.api import qubes.devices +import qubes.firewall import qubes.storage import qubes.utils import qubes.vm @@ -991,3 +992,37 @@ class QubesAdminAPI(qubes.api.AbstractQubesAPI): dev.backend_domain, dev.ident) self.dest.devices[devclass].detach(assignment) self.app.save() + + @qubes.api.method('admin.vm.firewall.Get', no_payload=True) + @asyncio.coroutine + def vm_firewall_get(self): + assert not self.arg + + self.fire_event_for_permission() + + return ''.join('{}\n'.format(rule.api_rule) + for rule in self.dest.firewall.rules) + + @qubes.api.method('admin.vm.firewall.Set') + @asyncio.coroutine + def vm_firewall_set(self, untrusted_payload): + assert not self.arg + rules = [] + for untrusted_line in untrusted_payload.decode('ascii', + errors='strict').splitlines(): + rule = qubes.firewall.Rule.from_api_string(untrusted_line) + rules.append(rule) + + self.fire_event_for_permission(rules=rules) + + self.dest.firewall.rules = rules + self.dest.firewall.save() + + @qubes.api.method('admin.vm.firewall.Reload', no_payload=True) + @asyncio.coroutine + def vm_firewall_reload(self): + assert not self.arg + + self.fire_event_for_permission() + + self.dest.fire_event('firewall-changed') diff --git a/qubes/firewall.py b/qubes/firewall.py index 1748047d..f2febf46 100644 --- a/qubes/firewall.py +++ b/qubes/firewall.py @@ -22,6 +22,7 @@ # import datetime +import string import subprocess import itertools @@ -34,13 +35,22 @@ import qubes.vm.qubesvm class RuleOption(object): - def __init__(self, value): - self._value = str(value) + def __init__(self, untrusted_value): + # subset of string.punctuation + safe_set = string.ascii_letters + string.digits + \ + ':;,./-_[]' + assert all(x in safe_set for x in str(untrusted_value)) + value = str(untrusted_value) + self._value = value @property def rule(self): raise NotImplementedError + @property + def api_rule(self): + return self.rule + def __str__(self): return self._value @@ -50,14 +60,15 @@ class RuleOption(object): # noinspection PyAbstractClass class RuleChoice(RuleOption): # pylint: disable=abstract-method - def __init__(self, value): - super(RuleChoice, self).__init__(value) + def __init__(self, untrusted_value): + # preliminary validation + super(RuleChoice, self).__init__(untrusted_value) self.allowed_values = \ [v for k, v in self.__class__.__dict__.items() if not k.startswith('__') and isinstance(v, str) and not v.startswith('__')] - if value not in self.allowed_values: - raise ValueError(value) + if untrusted_value not in self.allowed_values: + raise ValueError(untrusted_value) class Action(RuleChoice): @@ -81,14 +92,14 @@ class Proto(RuleChoice): class DstHost(RuleOption): '''Represent host/network address: either IPv4, IPv6, or DNS name''' - def __init__(self, value, prefixlen=None): - # TODO: in python >= 3.3 ipaddress module could be used - if value.count('/') > 1: - raise ValueError('Too many /: ' + value) - elif not value.count('/'): + def __init__(self, untrusted_value, prefixlen=None): + if untrusted_value.count('/') > 1: + raise ValueError('Too many /: ' + untrusted_value) + elif not untrusted_value.count('/'): # add prefix length to bare IP addresses try: - socket.inet_pton(socket.AF_INET6, value) + socket.inet_pton(socket.AF_INET6, untrusted_value) + value = untrusted_value self.prefixlen = prefixlen or 128 if self.prefixlen < 0 or self.prefixlen > 128: raise ValueError( @@ -97,10 +108,11 @@ class DstHost(RuleOption): self.type = 'dst6' except socket.error: try: - socket.inet_pton(socket.AF_INET, value) - if value.count('.') != 3: + socket.inet_pton(socket.AF_INET, untrusted_value) + if untrusted_value.count('.') != 3: raise ValueError( 'Invalid number of dots in IPv4 address') + value = untrusted_value self.prefixlen = prefixlen or 32 if self.prefixlen < 0 or self.prefixlen > 32: raise ValueError( @@ -110,28 +122,33 @@ class DstHost(RuleOption): except socket.error: self.type = 'dsthost' self.prefixlen = 0 + safe_set = string.ascii_lowercase + string.digits + '-._' + assert all(c in safe_set for c in untrusted_value) + value = untrusted_value else: - host, prefixlen = value.split('/', 1) - prefixlen = int(prefixlen) + untrusted_host, untrusted_prefixlen = untrusted_value.split('/', 1) + prefixlen = int(untrusted_prefixlen) if prefixlen < 0: raise ValueError('netmask must be non-negative') self.prefixlen = prefixlen try: - socket.inet_pton(socket.AF_INET6, host) + socket.inet_pton(socket.AF_INET6, untrusted_host) + value = untrusted_value if prefixlen > 128: raise ValueError('netmask for IPv6 must be <= 128') self.type = 'dst6' except socket.error: try: - socket.inet_pton(socket.AF_INET, host) + socket.inet_pton(socket.AF_INET, untrusted_host) if prefixlen > 32: raise ValueError('netmask for IPv4 must be <= 32') self.type = 'dst4' - if host.count('.') != 3: + if untrusted_host.count('.') != 3: raise ValueError( 'Invalid number of dots in IPv4 address') + value = untrusted_value except socket.error: - raise ValueError('Invalid IP address: ' + host) + raise ValueError('Invalid IP address: ' + untrusted_host) super(DstHost, self).__init__(value) @@ -141,15 +158,15 @@ class DstHost(RuleOption): class DstPorts(RuleOption): - def __init__(self, value): - if isinstance(value, int): - value = str(value) - if value.count('-') == 1: - self.range = [int(x) for x in value.split('-', 1)] - elif not value.count('-'): - self.range = [int(value), int(value)] + def __init__(self, untrusted_value): + if isinstance(untrusted_value, int): + untrusted_value = str(untrusted_value) + if untrusted_value.count('-') == 1: + self.range = [int(x) for x in untrusted_value.split('-', 1)] + elif not untrusted_value.count('-'): + self.range = [int(untrusted_value), int(untrusted_value)] else: - raise ValueError(value) + raise ValueError(untrusted_value) if any(port < 0 or port > 65536 for port in self.range): raise ValueError('Ports out of range') if self.range[0] > self.range[1]: @@ -164,11 +181,11 @@ class DstPorts(RuleOption): class IcmpType(RuleOption): - def __init__(self, value): - super(IcmpType, self).__init__(value) - value = int(value) - if value < 0 or value > 255: + def __init__(self, untrusted_value): + untrusted_value = int(untrusted_value) + if untrusted_value < 0 or untrusted_value > 255: raise ValueError('ICMP type out of range') + super(IcmpType, self).__init__(untrusted_value) @property def rule(self): @@ -184,24 +201,42 @@ class SpecialTarget(RuleChoice): class Expire(RuleOption): - def __init__(self, value): - super(Expire, self).__init__(value) - self.datetime = datetime.datetime.utcfromtimestamp(int(value)) + def __init__(self, untrusted_value): + super(Expire, self).__init__(untrusted_value) + self.datetime = datetime.datetime.utcfromtimestamp(int(untrusted_value)) @property def rule(self): return None + @property + def api_rule(self): + return 'expire=' + str(self) + @property def expired(self): return self.datetime < datetime.datetime.utcnow() class Comment(RuleOption): + # noinspection PyMissingConstructor + def __init__(self, untrusted_value): + # pylint: disable=super-init-not-called + # subset of string.punctuation + safe_set = string.ascii_letters + string.digits + \ + ':;,./-_[] ' + assert all(x in safe_set for x in str(untrusted_value)) + value = str(untrusted_value) + self._value = value + @property def rule(self): return None + @property + def api_rule(self): + return 'comment=' + str(self) + class Rule(qubes.PropertyHolder): def __init__(self, xml=None, **kwargs): @@ -311,6 +346,20 @@ class Rule(qubes.PropertyHolder): values.append(value.rule) return ' '.join(values) + @property + def api_rule(self): + values = [] + # put comment at the end + for prop in sorted(self.property_list(), + key=(lambda p: p.__name__ == 'comment')): + value = getattr(self, prop.__name__) + if value is None: + continue + if value.api_rule is None: + continue + values.append(value.api_rule) + return ' '.join(values) + @classmethod def from_xml_v1(cls, node, action): netmask = node.get('netmask') @@ -358,8 +407,39 @@ class Rule(qubes.PropertyHolder): return cls(**kwargs) + @classmethod + def from_api_string(cls, untrusted_rule): + '''Parse a single line of firewall rule''' + # comment is allowed to have spaces + untrusted_options, _, untrusted_comment = untrusted_rule.partition( + 'comment=') + # appropriate handlers in __init__ of individual options will perform + # option-specific validation + kwargs = {} + if untrusted_comment: + kwargs['comment'] = untrusted_comment + + for untrusted_option in untrusted_options.strip().split(' '): + untrusted_key, untrusted_value = untrusted_option.split('=', 1) + if untrusted_key in kwargs: + raise ValueError('Option \'{}\' already set'.format( + untrusted_key)) + if untrusted_key in [str(prop) for prop in cls.property_list()]: + kwargs[untrusted_key] = untrusted_value + elif untrusted_key in ('dst4', 'dst6', 'dstname'): + kwargs['dsthost'] = untrusted_value + else: + raise ValueError('Unknown firewall option') + + return cls(**kwargs) + def __eq__(self, other): - return self.rule == other.rule + if isinstance(other, Rule): + return self.api_rule == other.api_rule + return self.api_rule == str(other) + + def __hash__(self): + return hash(self.api_rule) class Firewall(object): @@ -496,7 +576,6 @@ class Firewall(object): subprocess.call(["sudo", "systemctl", "start", "qubes-reload-firewall@%s.timer" % self.vm.name]) - def qdb_entries(self, addr_family=None): '''Return firewall settings serialized for QubesDB entries diff --git a/qubes/tests/__init__.py b/qubes/tests/__init__.py index b4c2ec51..3d19249e 100644 --- a/qubes/tests/__init__.py +++ b/qubes/tests/__init__.py @@ -151,7 +151,9 @@ class TestEmitter(qubes.events.Emitter): effects = super(TestEmitter, self).fire_event(event, **kwargs) ev_kwargs = frozenset( (key, - frozenset(value.items()) if isinstance(value, dict) else value) + frozenset(value.items()) if isinstance(value, dict) + else tuple(value) if isinstance(value, list) + else value) for key, value in kwargs.items() ) self.fired_events[(event, ev_kwargs)] += 1 @@ -161,7 +163,9 @@ class TestEmitter(qubes.events.Emitter): effects = super(TestEmitter, self).fire_event_pre(event, **kwargs) ev_kwargs = frozenset( (key, - frozenset(value.items()) if isinstance(value, dict) else value) + frozenset(value.items()) if isinstance(value, dict) + else tuple(value) if isinstance(value, list) + else value) for key, value in kwargs.items() ) self.fired_events[(event, ev_kwargs)] += 1 diff --git a/qubes/tests/api_admin.py b/qubes/tests/api_admin.py index fde83162..f2453050 100644 --- a/qubes/tests/api_admin.py +++ b/qubes/tests/api_admin.py @@ -21,6 +21,7 @@ ''' Tests for management calls endpoints ''' import asyncio +import operator import os import shutil import unittest.mock @@ -29,6 +30,7 @@ import libvirt import qubes import qubes.devices +import qubes.firewall import qubes.api.admin import qubes.tests @@ -1734,6 +1736,131 @@ class TC_00_VMs(AdminAPITestCase): self.assertNotIn('+.some-tag', self.vm.tags) self.assertFalse(self.app.save.called) + def test_570_firewall_get(self): + self.vm.firewall.save = unittest.mock.Mock() + value = self.call_mgmt_func(b'admin.vm.firewall.Get', + b'test-vm1', b'') + self.assertEqual(value, 'action=accept\n') + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_571_firewall_get_non_default(self): + self.vm.firewall.save = unittest.mock.Mock() + self.vm.firewall.rules = [ + qubes.firewall.Rule(action='accept', proto='tcp', + dstports='1-1024'), + qubes.firewall.Rule(action='drop', proto='icmp', + comment='No ICMP'), + qubes.firewall.Rule(action='drop', proto='udp', + expire='1499450306'), + qubes.firewall.Rule(action='accept'), + ] + value = self.call_mgmt_func(b'admin.vm.firewall.Get', + b'test-vm1', b'') + self.assertEqual(value, + 'action=accept proto=tcp dstports=1-1024\n' + 'action=drop proto=icmp comment=No ICMP\n' + 'action=drop expire=1499450306 proto=udp\n' + 'action=accept\n') + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_580_firewall_set_simple(self): + self.vm.firewall.save = unittest.mock.Mock() + value = self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', b'action=accept\n') + self.assertEqual(self.vm.firewall.rules, + ['action=accept']) + self.assertTrue(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_581_firewall_set_multi(self): + self.vm.firewall.save = unittest.mock.Mock() + rules = [ + qubes.firewall.Rule(action='accept', proto='tcp', + dstports='1-1024'), + qubes.firewall.Rule(action='drop', proto='icmp', + comment='No ICMP'), + qubes.firewall.Rule(action='drop', proto='udp', + expire='1499450306'), + qubes.firewall.Rule(action='accept'), + ] + rules_txt = ( + 'action=accept proto=tcp dstports=1-1024\n' + 'action=drop proto=icmp comment=No ICMP\n' + 'action=drop expire=1499450306 proto=udp\n' + 'action=accept\n') + value = self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, rules) + self.assertTrue(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_582_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'action=accept protoxyz=tcp dst4=127.0.0.1\n' + 'action=drop\n') + with self.assertRaises(ValueError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_583_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'proto=tcp dstports=1-1024\n' + 'action=drop\n') + with self.assertRaises(AssertionError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_584_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'action=accept proto=tcp dstports=1-1024 ' + 'action=drop\n') + with self.assertRaises(ValueError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_585_firewall_set_invalid(self): + self.vm.firewall.save = unittest.mock.Mock() + rules_txt = ( + 'action=accept dstports=1-1024 comment=ążźł\n' + 'action=drop\n') + with self.assertRaises(UnicodeDecodeError): + self.call_mgmt_func(b'admin.vm.firewall.Set', + b'test-vm1', b'', rules_txt.encode()) + self.assertEqual(self.vm.firewall.rules, + [qubes.firewall.Rule(action='accept')]) + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_590_firewall_reload(self): + self.vm.firewall.save = unittest.mock.Mock() + self.app.domains['test-vm1'].fire_event = self.emitter.fire_event + self.app.domains['test-vm1'].fire_event_pre = \ + self.emitter.fire_event_pre + value = self.call_mgmt_func(b'admin.vm.firewall.Reload', + b'test-vm1', b'') + self.assertIsNone(value) + self.assertEventFired(self.emitter, 'firewall-changed') + self.assertFalse(self.vm.firewall.save.called) + self.assertFalse(self.app.save.called) + + def test_990_vm_unexpected_payload(self): methods_with_no_payload = [ b'admin.vm.List', @@ -1752,8 +1879,7 @@ class TC_00_VMs(AdminAPITestCase): b'admin.vm.tag.Remove', b'admin.vm.tag.Set', b'admin.vm.firewall.Get', - b'admin.vm.firewall.RemoveRule', - b'admin.vm.firewall.Flush', + b'admin.vm.firewall.Reload', b'admin.vm.device.pci.Attach', b'admin.vm.device.pci.Detach', b'admin.vm.device.pci.List', @@ -1805,8 +1931,9 @@ class TC_00_VMs(AdminAPITestCase): b'admin.vm.property.List', b'admin.vm.feature.List', b'admin.vm.tag.List', - b'admin.vm.firewall.List', - b'admin.vm.firewall.Flush', + b'admin.vm.firewall.Get', + b'admin.vm.firewall.Set', + b'admin.vm.firewall.Reload', b'admin.vm.microphone.Attach', b'admin.vm.microphone.Detach', b'admin.vm.microphone.Status', @@ -2007,9 +2134,8 @@ class TC_00_VMs(AdminAPITestCase): b'admin.vm.tag.Remove', b'admin.vm.tag.Set', b'admin.vm.firewall.Get', - b'admin.vm.firewall.RemoveRule', - b'admin.vm.firewall.InsertRule', - b'admin.vm.firewall.Flush', + b'admin.vm.firewall.Set', + b'admin.vm.firewall.Reload', b'admin.vm.device.pci.Attach', b'admin.vm.device.pci.Detach', b'admin.vm.device.pci.List', diff --git a/qubes/tests/firewall.py b/qubes/tests/firewall.py index 5c1a330e..f4594d45 100644 --- a/qubes/tests/firewall.py +++ b/qubes/tests/firewall.py @@ -80,6 +80,7 @@ class TC_01_Action(qubes.tests.QubesTestCase): def test_001_rule(self): instance = qubes.firewall.Action('accept') self.assertEqual(instance.rule, 'action=accept') + self.assertEqual(instance.api_rule, 'action=accept') # noinspection PyPep8Naming @@ -93,6 +94,7 @@ class TC_02_Proto(qubes.tests.QubesTestCase): def test_001_rule(self): instance = qubes.firewall.Proto('tcp') self.assertEqual(instance.rule, 'proto=tcp') + self.assertEqual(instance.api_rule, 'proto=tcp') # noinspection PyPep8Naming @@ -157,6 +159,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase): self.assertEqual(instance.prefixlen, 128) self.assertEqual(str(instance), '2001:abcd:efab::3/128') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::3/128') + self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::3/128') def test_011_ipv6_prefixlen(self): with self.assertNotRaises(ValueError): @@ -165,6 +168,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase): self.assertEqual(instance.prefixlen, 64) self.assertEqual(str(instance), '2001:abcd:efab::/64') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') + self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64') def test_012_ipv6_parse_prefixlen(self): with self.assertNotRaises(ValueError): @@ -173,6 +177,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase): self.assertEqual(instance.prefixlen, 64) self.assertEqual(str(instance), '2001:abcd:efab::/64') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') + self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64') def test_013_ipv6_invalid_prefix(self): with self.assertRaises(ValueError): @@ -211,6 +216,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase): self.assertEqual(str(instance), '80') self.assertEqual(instance.range, [80, 80]) self.assertEqual(instance.rule, 'dstports=80-80') + self.assertEqual(instance.api_rule, 'dstports=80-80') def test_001_single_int(self): with self.assertNotRaises(ValueError): @@ -218,6 +224,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase): self.assertEqual(str(instance), '80') self.assertEqual(instance.range, [80, 80]) self.assertEqual(instance.rule, 'dstports=80-80') + self.assertEqual(instance.api_rule, 'dstports=80-80') def test_002_range(self): with self.assertNotRaises(ValueError): @@ -261,6 +268,7 @@ class TC_04_IcmpType(qubes.tests.QubesTestCase): instance = qubes.firewall.IcmpType('8') self.assertEqual(str(instance), '8') self.assertEqual(instance.rule, 'icmptype=8') + self.assertEqual(instance.api_rule, 'icmptype=8') def test_002_invalid(self): with self.assertRaises(ValueError): @@ -283,6 +291,7 @@ class TC_05_SpecialTarget(qubes.tests.QubesTestCase): def test_001_rule(self): instance = qubes.firewall.SpecialTarget('dns') self.assertEqual(instance.rule, 'specialtarget=dns') + self.assertEqual(instance.api_rule, 'specialtarget=dns') class TC_06_Expire(qubes.tests.QubesTestCase): @@ -290,6 +299,7 @@ class TC_06_Expire(qubes.tests.QubesTestCase): with self.assertNotRaises(ValueError): instance = qubes.firewall.Expire(1463292452) self.assertEqual(str(instance), '1463292452') + self.assertEqual(instance.api_rule, 'expire=1463292452') self.assertEqual(instance.datetime, datetime.datetime(2016, 5, 15, 6, 7, 32)) self.assertIsNone(instance.rule) @@ -322,6 +332,7 @@ class TC_07_Comment(qubes.tests.QubesTestCase): with self.assertNotRaises(ValueError): instance = qubes.firewall.Comment('Some comment') self.assertEqual(str(instance), 'Some comment') + self.assertEqual(instance.api_rule, 'comment=Some comment') self.assertIsNone(instance.rule) @@ -445,6 +456,33 @@ class TC_08_Rule(qubes.tests.QubesTestCase): self.assertIsNone(rule.proto) self.assertIsNone(rule.dstports) + def test_008_from_api_string(self): + rule_txt = 'action=drop proto=tcp dstports=80-80' + with self.assertNotRaises(ValueError): + rule = qubes.firewall.Rule.from_api_string( + rule_txt) + self.assertEqual(rule.dstports.range, [80, 80]) + self.assertEqual(rule.proto, 'tcp') + self.assertEqual(rule.action, 'drop') + self.assertIsNone(rule.dsthost) + self.assertIsNone(rule.expire) + self.assertIsNone(rule.comment) + self.assertEqual(rule.api_rule, rule_txt) + + def test_009_from_api_string(self): + rule_txt = 'action=accept expire=1463292452 proto=tcp ' \ + 'comment=Some comment, with spaces' + with self.assertNotRaises(ValueError): + rule = qubes.firewall.Rule.from_api_string( + rule_txt) + self.assertEqual(rule.comment, 'Some comment, with spaces') + self.assertEqual(rule.proto, 'tcp') + self.assertEqual(rule.action, 'accept') + self.assertEqual(rule.expire, '1463292452') + self.assertIsNone(rule.dstports) + self.assertIsNone(rule.dsthost) + self.assertEqual(rule.api_rule, rule_txt) + class TC_10_Firewall(qubes.tests.QubesTestCase): def setUp(self):