api/admin: firewall-related methods

In the end firewall is implemented as .Get and .Set rules, with policy
statically set to 'drop'. This way allow atomic firewall updates.

Since we already have appropriate firewall format handling in
qubes.firewall module - reuse it from there, but adjust the code to be
prepared for potentially malicious input. And also mark such variables
with untrusted_ prefix.

There is also third method: .Reload - which cause firewall reload
without making any change.

QubesOS/qubes-issues#2622
Fixes QubesOS/qubes-issues#2869
This commit is contained in:
Marek Marczykowski-Górecki 2017-06-26 12:58:14 +02:00
parent 842efb577d
commit 0200fdadcb
No known key found for this signature in database
GPG Key ID: 063938BA42CFA724
5 changed files with 328 additions and 46 deletions

View File

@ -30,6 +30,7 @@ import libvirt
import qubes.api
import qubes.devices
import qubes.firewall
import qubes.storage
import qubes.utils
import qubes.vm
@ -991,3 +992,37 @@ class QubesAdminAPI(qubes.api.AbstractQubesAPI):
dev.backend_domain, dev.ident)
self.dest.devices[devclass].detach(assignment)
self.app.save()
@qubes.api.method('admin.vm.firewall.Get', no_payload=True)
@asyncio.coroutine
def vm_firewall_get(self):
assert not self.arg
self.fire_event_for_permission()
return ''.join('{}\n'.format(rule.api_rule)
for rule in self.dest.firewall.rules)
@qubes.api.method('admin.vm.firewall.Set')
@asyncio.coroutine
def vm_firewall_set(self, untrusted_payload):
assert not self.arg
rules = []
for untrusted_line in untrusted_payload.decode('ascii',
errors='strict').splitlines():
rule = qubes.firewall.Rule.from_api_string(untrusted_line)
rules.append(rule)
self.fire_event_for_permission(rules=rules)
self.dest.firewall.rules = rules
self.dest.firewall.save()
@qubes.api.method('admin.vm.firewall.Reload', no_payload=True)
@asyncio.coroutine
def vm_firewall_reload(self):
assert not self.arg
self.fire_event_for_permission()
self.dest.fire_event('firewall-changed')

View File

@ -22,6 +22,7 @@
#
import datetime
import string
import subprocess
import itertools
@ -34,13 +35,22 @@ import qubes.vm.qubesvm
class RuleOption(object):
def __init__(self, value):
self._value = str(value)
def __init__(self, untrusted_value):
# subset of string.punctuation
safe_set = string.ascii_letters + string.digits + \
':;,./-_[]'
assert all(x in safe_set for x in str(untrusted_value))
value = str(untrusted_value)
self._value = value
@property
def rule(self):
raise NotImplementedError
@property
def api_rule(self):
return self.rule
def __str__(self):
return self._value
@ -50,14 +60,15 @@ class RuleOption(object):
# noinspection PyAbstractClass
class RuleChoice(RuleOption):
# pylint: disable=abstract-method
def __init__(self, value):
super(RuleChoice, self).__init__(value)
def __init__(self, untrusted_value):
# preliminary validation
super(RuleChoice, self).__init__(untrusted_value)
self.allowed_values = \
[v for k, v in self.__class__.__dict__.items()
if not k.startswith('__') and isinstance(v, str) and
not v.startswith('__')]
if value not in self.allowed_values:
raise ValueError(value)
if untrusted_value not in self.allowed_values:
raise ValueError(untrusted_value)
class Action(RuleChoice):
@ -81,14 +92,14 @@ class Proto(RuleChoice):
class DstHost(RuleOption):
'''Represent host/network address: either IPv4, IPv6, or DNS name'''
def __init__(self, value, prefixlen=None):
# TODO: in python >= 3.3 ipaddress module could be used
if value.count('/') > 1:
raise ValueError('Too many /: ' + value)
elif not value.count('/'):
def __init__(self, untrusted_value, prefixlen=None):
if untrusted_value.count('/') > 1:
raise ValueError('Too many /: ' + untrusted_value)
elif not untrusted_value.count('/'):
# add prefix length to bare IP addresses
try:
socket.inet_pton(socket.AF_INET6, value)
socket.inet_pton(socket.AF_INET6, untrusted_value)
value = untrusted_value
self.prefixlen = prefixlen or 128
if self.prefixlen < 0 or self.prefixlen > 128:
raise ValueError(
@ -97,10 +108,11 @@ class DstHost(RuleOption):
self.type = 'dst6'
except socket.error:
try:
socket.inet_pton(socket.AF_INET, value)
if value.count('.') != 3:
socket.inet_pton(socket.AF_INET, untrusted_value)
if untrusted_value.count('.') != 3:
raise ValueError(
'Invalid number of dots in IPv4 address')
value = untrusted_value
self.prefixlen = prefixlen or 32
if self.prefixlen < 0 or self.prefixlen > 32:
raise ValueError(
@ -110,28 +122,33 @@ class DstHost(RuleOption):
except socket.error:
self.type = 'dsthost'
self.prefixlen = 0
safe_set = string.ascii_lowercase + string.digits + '-._'
assert all(c in safe_set for c in untrusted_value)
value = untrusted_value
else:
host, prefixlen = value.split('/', 1)
prefixlen = int(prefixlen)
untrusted_host, untrusted_prefixlen = untrusted_value.split('/', 1)
prefixlen = int(untrusted_prefixlen)
if prefixlen < 0:
raise ValueError('netmask must be non-negative')
self.prefixlen = prefixlen
try:
socket.inet_pton(socket.AF_INET6, host)
socket.inet_pton(socket.AF_INET6, untrusted_host)
value = untrusted_value
if prefixlen > 128:
raise ValueError('netmask for IPv6 must be <= 128')
self.type = 'dst6'
except socket.error:
try:
socket.inet_pton(socket.AF_INET, host)
socket.inet_pton(socket.AF_INET, untrusted_host)
if prefixlen > 32:
raise ValueError('netmask for IPv4 must be <= 32')
self.type = 'dst4'
if host.count('.') != 3:
if untrusted_host.count('.') != 3:
raise ValueError(
'Invalid number of dots in IPv4 address')
value = untrusted_value
except socket.error:
raise ValueError('Invalid IP address: ' + host)
raise ValueError('Invalid IP address: ' + untrusted_host)
super(DstHost, self).__init__(value)
@ -141,15 +158,15 @@ class DstHost(RuleOption):
class DstPorts(RuleOption):
def __init__(self, value):
if isinstance(value, int):
value = str(value)
if value.count('-') == 1:
self.range = [int(x) for x in value.split('-', 1)]
elif not value.count('-'):
self.range = [int(value), int(value)]
def __init__(self, untrusted_value):
if isinstance(untrusted_value, int):
untrusted_value = str(untrusted_value)
if untrusted_value.count('-') == 1:
self.range = [int(x) for x in untrusted_value.split('-', 1)]
elif not untrusted_value.count('-'):
self.range = [int(untrusted_value), int(untrusted_value)]
else:
raise ValueError(value)
raise ValueError(untrusted_value)
if any(port < 0 or port > 65536 for port in self.range):
raise ValueError('Ports out of range')
if self.range[0] > self.range[1]:
@ -164,11 +181,11 @@ class DstPorts(RuleOption):
class IcmpType(RuleOption):
def __init__(self, value):
super(IcmpType, self).__init__(value)
value = int(value)
if value < 0 or value > 255:
def __init__(self, untrusted_value):
untrusted_value = int(untrusted_value)
if untrusted_value < 0 or untrusted_value > 255:
raise ValueError('ICMP type out of range')
super(IcmpType, self).__init__(untrusted_value)
@property
def rule(self):
@ -184,24 +201,42 @@ class SpecialTarget(RuleChoice):
class Expire(RuleOption):
def __init__(self, value):
super(Expire, self).__init__(value)
self.datetime = datetime.datetime.utcfromtimestamp(int(value))
def __init__(self, untrusted_value):
super(Expire, self).__init__(untrusted_value)
self.datetime = datetime.datetime.utcfromtimestamp(int(untrusted_value))
@property
def rule(self):
return None
@property
def api_rule(self):
return 'expire=' + str(self)
@property
def expired(self):
return self.datetime < datetime.datetime.utcnow()
class Comment(RuleOption):
# noinspection PyMissingConstructor
def __init__(self, untrusted_value):
# pylint: disable=super-init-not-called
# subset of string.punctuation
safe_set = string.ascii_letters + string.digits + \
':;,./-_[] '
assert all(x in safe_set for x in str(untrusted_value))
value = str(untrusted_value)
self._value = value
@property
def rule(self):
return None
@property
def api_rule(self):
return 'comment=' + str(self)
class Rule(qubes.PropertyHolder):
def __init__(self, xml=None, **kwargs):
@ -311,6 +346,20 @@ class Rule(qubes.PropertyHolder):
values.append(value.rule)
return ' '.join(values)
@property
def api_rule(self):
values = []
# put comment at the end
for prop in sorted(self.property_list(),
key=(lambda p: p.__name__ == 'comment')):
value = getattr(self, prop.__name__)
if value is None:
continue
if value.api_rule is None:
continue
values.append(value.api_rule)
return ' '.join(values)
@classmethod
def from_xml_v1(cls, node, action):
netmask = node.get('netmask')
@ -358,8 +407,39 @@ class Rule(qubes.PropertyHolder):
return cls(**kwargs)
@classmethod
def from_api_string(cls, untrusted_rule):
'''Parse a single line of firewall rule'''
# comment is allowed to have spaces
untrusted_options, _, untrusted_comment = untrusted_rule.partition(
'comment=')
# appropriate handlers in __init__ of individual options will perform
# option-specific validation
kwargs = {}
if untrusted_comment:
kwargs['comment'] = untrusted_comment
for untrusted_option in untrusted_options.strip().split(' '):
untrusted_key, untrusted_value = untrusted_option.split('=', 1)
if untrusted_key in kwargs:
raise ValueError('Option \'{}\' already set'.format(
untrusted_key))
if untrusted_key in [str(prop) for prop in cls.property_list()]:
kwargs[untrusted_key] = untrusted_value
elif untrusted_key in ('dst4', 'dst6', 'dstname'):
kwargs['dsthost'] = untrusted_value
else:
raise ValueError('Unknown firewall option')
return cls(**kwargs)
def __eq__(self, other):
return self.rule == other.rule
if isinstance(other, Rule):
return self.api_rule == other.api_rule
return self.api_rule == str(other)
def __hash__(self):
return hash(self.api_rule)
class Firewall(object):
@ -496,7 +576,6 @@ class Firewall(object):
subprocess.call(["sudo", "systemctl", "start",
"qubes-reload-firewall@%s.timer" % self.vm.name])
def qdb_entries(self, addr_family=None):
'''Return firewall settings serialized for QubesDB entries

View File

@ -151,7 +151,9 @@ class TestEmitter(qubes.events.Emitter):
effects = super(TestEmitter, self).fire_event(event, **kwargs)
ev_kwargs = frozenset(
(key,
frozenset(value.items()) if isinstance(value, dict) else value)
frozenset(value.items()) if isinstance(value, dict)
else tuple(value) if isinstance(value, list)
else value)
for key, value in kwargs.items()
)
self.fired_events[(event, ev_kwargs)] += 1
@ -161,7 +163,9 @@ class TestEmitter(qubes.events.Emitter):
effects = super(TestEmitter, self).fire_event_pre(event, **kwargs)
ev_kwargs = frozenset(
(key,
frozenset(value.items()) if isinstance(value, dict) else value)
frozenset(value.items()) if isinstance(value, dict)
else tuple(value) if isinstance(value, list)
else value)
for key, value in kwargs.items()
)
self.fired_events[(event, ev_kwargs)] += 1

View File

@ -21,6 +21,7 @@
''' Tests for management calls endpoints '''
import asyncio
import operator
import os
import shutil
import unittest.mock
@ -29,6 +30,7 @@ import libvirt
import qubes
import qubes.devices
import qubes.firewall
import qubes.api.admin
import qubes.tests
@ -1734,6 +1736,131 @@ class TC_00_VMs(AdminAPITestCase):
self.assertNotIn('+.some-tag', self.vm.tags)
self.assertFalse(self.app.save.called)
def test_570_firewall_get(self):
self.vm.firewall.save = unittest.mock.Mock()
value = self.call_mgmt_func(b'admin.vm.firewall.Get',
b'test-vm1', b'')
self.assertEqual(value, 'action=accept\n')
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_571_firewall_get_non_default(self):
self.vm.firewall.save = unittest.mock.Mock()
self.vm.firewall.rules = [
qubes.firewall.Rule(action='accept', proto='tcp',
dstports='1-1024'),
qubes.firewall.Rule(action='drop', proto='icmp',
comment='No ICMP'),
qubes.firewall.Rule(action='drop', proto='udp',
expire='1499450306'),
qubes.firewall.Rule(action='accept'),
]
value = self.call_mgmt_func(b'admin.vm.firewall.Get',
b'test-vm1', b'')
self.assertEqual(value,
'action=accept proto=tcp dstports=1-1024\n'
'action=drop proto=icmp comment=No ICMP\n'
'action=drop expire=1499450306 proto=udp\n'
'action=accept\n')
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_580_firewall_set_simple(self):
self.vm.firewall.save = unittest.mock.Mock()
value = self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', b'action=accept\n')
self.assertEqual(self.vm.firewall.rules,
['action=accept'])
self.assertTrue(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_581_firewall_set_multi(self):
self.vm.firewall.save = unittest.mock.Mock()
rules = [
qubes.firewall.Rule(action='accept', proto='tcp',
dstports='1-1024'),
qubes.firewall.Rule(action='drop', proto='icmp',
comment='No ICMP'),
qubes.firewall.Rule(action='drop', proto='udp',
expire='1499450306'),
qubes.firewall.Rule(action='accept'),
]
rules_txt = (
'action=accept proto=tcp dstports=1-1024\n'
'action=drop proto=icmp comment=No ICMP\n'
'action=drop expire=1499450306 proto=udp\n'
'action=accept\n')
value = self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules, rules)
self.assertTrue(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_582_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'action=accept protoxyz=tcp dst4=127.0.0.1\n'
'action=drop\n')
with self.assertRaises(ValueError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_583_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'proto=tcp dstports=1-1024\n'
'action=drop\n')
with self.assertRaises(AssertionError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_584_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'action=accept proto=tcp dstports=1-1024 '
'action=drop\n')
with self.assertRaises(ValueError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_585_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'action=accept dstports=1-1024 comment=ążźł\n'
'action=drop\n')
with self.assertRaises(UnicodeDecodeError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_590_firewall_reload(self):
self.vm.firewall.save = unittest.mock.Mock()
self.app.domains['test-vm1'].fire_event = self.emitter.fire_event
self.app.domains['test-vm1'].fire_event_pre = \
self.emitter.fire_event_pre
value = self.call_mgmt_func(b'admin.vm.firewall.Reload',
b'test-vm1', b'')
self.assertIsNone(value)
self.assertEventFired(self.emitter, 'firewall-changed')
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_990_vm_unexpected_payload(self):
methods_with_no_payload = [
b'admin.vm.List',
@ -1752,8 +1879,7 @@ class TC_00_VMs(AdminAPITestCase):
b'admin.vm.tag.Remove',
b'admin.vm.tag.Set',
b'admin.vm.firewall.Get',
b'admin.vm.firewall.RemoveRule',
b'admin.vm.firewall.Flush',
b'admin.vm.firewall.Reload',
b'admin.vm.device.pci.Attach',
b'admin.vm.device.pci.Detach',
b'admin.vm.device.pci.List',
@ -1805,8 +1931,9 @@ class TC_00_VMs(AdminAPITestCase):
b'admin.vm.property.List',
b'admin.vm.feature.List',
b'admin.vm.tag.List',
b'admin.vm.firewall.List',
b'admin.vm.firewall.Flush',
b'admin.vm.firewall.Get',
b'admin.vm.firewall.Set',
b'admin.vm.firewall.Reload',
b'admin.vm.microphone.Attach',
b'admin.vm.microphone.Detach',
b'admin.vm.microphone.Status',
@ -2007,9 +2134,8 @@ class TC_00_VMs(AdminAPITestCase):
b'admin.vm.tag.Remove',
b'admin.vm.tag.Set',
b'admin.vm.firewall.Get',
b'admin.vm.firewall.RemoveRule',
b'admin.vm.firewall.InsertRule',
b'admin.vm.firewall.Flush',
b'admin.vm.firewall.Set',
b'admin.vm.firewall.Reload',
b'admin.vm.device.pci.Attach',
b'admin.vm.device.pci.Detach',
b'admin.vm.device.pci.List',

View File

@ -80,6 +80,7 @@ class TC_01_Action(qubes.tests.QubesTestCase):
def test_001_rule(self):
instance = qubes.firewall.Action('accept')
self.assertEqual(instance.rule, 'action=accept')
self.assertEqual(instance.api_rule, 'action=accept')
# noinspection PyPep8Naming
@ -93,6 +94,7 @@ class TC_02_Proto(qubes.tests.QubesTestCase):
def test_001_rule(self):
instance = qubes.firewall.Proto('tcp')
self.assertEqual(instance.rule, 'proto=tcp')
self.assertEqual(instance.api_rule, 'proto=tcp')
# noinspection PyPep8Naming
@ -157,6 +159,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase):
self.assertEqual(instance.prefixlen, 128)
self.assertEqual(str(instance), '2001:abcd:efab::3/128')
self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::3/128')
self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::3/128')
def test_011_ipv6_prefixlen(self):
with self.assertNotRaises(ValueError):
@ -165,6 +168,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase):
self.assertEqual(instance.prefixlen, 64)
self.assertEqual(str(instance), '2001:abcd:efab::/64')
self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64')
self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64')
def test_012_ipv6_parse_prefixlen(self):
with self.assertNotRaises(ValueError):
@ -173,6 +177,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase):
self.assertEqual(instance.prefixlen, 64)
self.assertEqual(str(instance), '2001:abcd:efab::/64')
self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64')
self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64')
def test_013_ipv6_invalid_prefix(self):
with self.assertRaises(ValueError):
@ -211,6 +216,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase):
self.assertEqual(str(instance), '80')
self.assertEqual(instance.range, [80, 80])
self.assertEqual(instance.rule, 'dstports=80-80')
self.assertEqual(instance.api_rule, 'dstports=80-80')
def test_001_single_int(self):
with self.assertNotRaises(ValueError):
@ -218,6 +224,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase):
self.assertEqual(str(instance), '80')
self.assertEqual(instance.range, [80, 80])
self.assertEqual(instance.rule, 'dstports=80-80')
self.assertEqual(instance.api_rule, 'dstports=80-80')
def test_002_range(self):
with self.assertNotRaises(ValueError):
@ -261,6 +268,7 @@ class TC_04_IcmpType(qubes.tests.QubesTestCase):
instance = qubes.firewall.IcmpType('8')
self.assertEqual(str(instance), '8')
self.assertEqual(instance.rule, 'icmptype=8')
self.assertEqual(instance.api_rule, 'icmptype=8')
def test_002_invalid(self):
with self.assertRaises(ValueError):
@ -283,6 +291,7 @@ class TC_05_SpecialTarget(qubes.tests.QubesTestCase):
def test_001_rule(self):
instance = qubes.firewall.SpecialTarget('dns')
self.assertEqual(instance.rule, 'specialtarget=dns')
self.assertEqual(instance.api_rule, 'specialtarget=dns')
class TC_06_Expire(qubes.tests.QubesTestCase):
@ -290,6 +299,7 @@ class TC_06_Expire(qubes.tests.QubesTestCase):
with self.assertNotRaises(ValueError):
instance = qubes.firewall.Expire(1463292452)
self.assertEqual(str(instance), '1463292452')
self.assertEqual(instance.api_rule, 'expire=1463292452')
self.assertEqual(instance.datetime,
datetime.datetime(2016, 5, 15, 6, 7, 32))
self.assertIsNone(instance.rule)
@ -322,6 +332,7 @@ class TC_07_Comment(qubes.tests.QubesTestCase):
with self.assertNotRaises(ValueError):
instance = qubes.firewall.Comment('Some comment')
self.assertEqual(str(instance), 'Some comment')
self.assertEqual(instance.api_rule, 'comment=Some comment')
self.assertIsNone(instance.rule)
@ -445,6 +456,33 @@ class TC_08_Rule(qubes.tests.QubesTestCase):
self.assertIsNone(rule.proto)
self.assertIsNone(rule.dstports)
def test_008_from_api_string(self):
rule_txt = 'action=drop proto=tcp dstports=80-80'
with self.assertNotRaises(ValueError):
rule = qubes.firewall.Rule.from_api_string(
rule_txt)
self.assertEqual(rule.dstports.range, [80, 80])
self.assertEqual(rule.proto, 'tcp')
self.assertEqual(rule.action, 'drop')
self.assertIsNone(rule.dsthost)
self.assertIsNone(rule.expire)
self.assertIsNone(rule.comment)
self.assertEqual(rule.api_rule, rule_txt)
def test_009_from_api_string(self):
rule_txt = 'action=accept expire=1463292452 proto=tcp ' \
'comment=Some comment, with spaces'
with self.assertNotRaises(ValueError):
rule = qubes.firewall.Rule.from_api_string(
rule_txt)
self.assertEqual(rule.comment, 'Some comment, with spaces')
self.assertEqual(rule.proto, 'tcp')
self.assertEqual(rule.action, 'accept')
self.assertEqual(rule.expire, '1463292452')
self.assertIsNone(rule.dstports)
self.assertIsNone(rule.dsthost)
self.assertEqual(rule.api_rule, rule_txt)
class TC_10_Firewall(qubes.tests.QubesTestCase):
def setUp(self):