Merge branch 'core3-firewall2'

This commit is contained in:
Marek Marczykowski-Górecki 2017-07-04 03:38:59 +02:00
commit 3748eb3e2b
No known key found for this signature in database
GPG Key ID: 063938BA42CFA724
6 changed files with 358 additions and 79 deletions

View File

@ -331,11 +331,12 @@ class property(object): # pylint: disable=redefined-builtin,invalid-name
# do not treat type='str' as sufficient validation # do not treat type='str' as sufficient validation
if self.type is not None and self.type is not str: if self.type is not None and self.type is not str:
# assume specific type will preform enough validation # assume specific type will preform enough validation
try:
untrusted_newvalue = untrusted_newvalue.decode('ascii',
errors='strict')
except UnicodeDecodeError:
raise qubes.exc.QubesValueError
if self.type is bool: if self.type is bool:
try:
untrusted_newvalue = untrusted_newvalue.decode('ascii')
except UnicodeDecodeError:
raise qubes.exc.QubesValueError
return self.bool(None, None, untrusted_newvalue) return self.bool(None, None, untrusted_newvalue)
else: else:
try: try:

View File

@ -30,6 +30,7 @@ import libvirt
import qubes.api import qubes.api
import qubes.devices import qubes.devices
import qubes.firewall
import qubes.storage import qubes.storage
import qubes.utils import qubes.utils
import qubes.vm import qubes.vm
@ -992,3 +993,38 @@ class QubesAdminAPI(qubes.api.AbstractQubesAPI):
dev.backend_domain, dev.ident) dev.backend_domain, dev.ident)
self.dest.devices[devclass].detach(assignment) self.dest.devices[devclass].detach(assignment)
self.app.save() self.app.save()
@qubes.api.method('admin.vm.firewall.Get', no_payload=True)
@asyncio.coroutine
def vm_firewall_get(self):
assert not self.arg
self.fire_event_for_permission()
return ''.join('{}\n'.format(rule.api_rule)
for rule in self.dest.firewall.rules)
@qubes.api.method('admin.vm.firewall.Set')
@asyncio.coroutine
def vm_firewall_set(self, untrusted_payload):
assert not self.arg
rules = []
for untrusted_line in untrusted_payload.decode('ascii',
errors='strict').splitlines():
rule = qubes.firewall.Rule.from_api_string(
untrusted_rule=untrusted_line)
rules.append(rule)
self.fire_event_for_permission(rules=rules)
self.dest.firewall.rules = rules
self.dest.firewall.save()
@qubes.api.method('admin.vm.firewall.Reload', no_payload=True)
@asyncio.coroutine
def vm_firewall_reload(self):
assert not self.arg
self.fire_event_for_permission()
self.dest.fire_event('firewall-changed')

View File

@ -22,6 +22,7 @@
# #
import datetime import datetime
import string
import subprocess import subprocess
import itertools import itertools
@ -34,13 +35,22 @@ import qubes.vm.qubesvm
class RuleOption(object): class RuleOption(object):
def __init__(self, value): def __init__(self, untrusted_value):
self._value = str(value) # subset of string.punctuation
safe_set = string.ascii_letters + string.digits + \
':;,./-_[]'
assert all(x in safe_set for x in str(untrusted_value))
value = str(untrusted_value)
self._value = value
@property @property
def rule(self): def rule(self):
raise NotImplementedError raise NotImplementedError
@property
def api_rule(self):
return self.rule
def __str__(self): def __str__(self):
return self._value return self._value
@ -50,14 +60,15 @@ class RuleOption(object):
# noinspection PyAbstractClass # noinspection PyAbstractClass
class RuleChoice(RuleOption): class RuleChoice(RuleOption):
# pylint: disable=abstract-method # pylint: disable=abstract-method
def __init__(self, value): def __init__(self, untrusted_value):
super(RuleChoice, self).__init__(value) # preliminary validation
super(RuleChoice, self).__init__(untrusted_value)
self.allowed_values = \ self.allowed_values = \
[v for k, v in self.__class__.__dict__.items() [v for k, v in self.__class__.__dict__.items()
if not k.startswith('__') and isinstance(v, str) and if not k.startswith('__') and isinstance(v, str) and
not v.startswith('__')] not v.startswith('__')]
if value not in self.allowed_values: if untrusted_value not in self.allowed_values:
raise ValueError(value) raise ValueError(untrusted_value)
class Action(RuleChoice): class Action(RuleChoice):
@ -81,14 +92,14 @@ class Proto(RuleChoice):
class DstHost(RuleOption): class DstHost(RuleOption):
'''Represent host/network address: either IPv4, IPv6, or DNS name''' '''Represent host/network address: either IPv4, IPv6, or DNS name'''
def __init__(self, value, prefixlen=None): def __init__(self, untrusted_value, prefixlen=None):
# TODO: in python >= 3.3 ipaddress module could be used if untrusted_value.count('/') > 1:
if value.count('/') > 1: raise ValueError('Too many /: ' + untrusted_value)
raise ValueError('Too many /: ' + value) elif not untrusted_value.count('/'):
elif not value.count('/'):
# add prefix length to bare IP addresses # add prefix length to bare IP addresses
try: try:
socket.inet_pton(socket.AF_INET6, value) socket.inet_pton(socket.AF_INET6, untrusted_value)
value = untrusted_value
self.prefixlen = prefixlen or 128 self.prefixlen = prefixlen or 128
if self.prefixlen < 0 or self.prefixlen > 128: if self.prefixlen < 0 or self.prefixlen > 128:
raise ValueError( raise ValueError(
@ -97,10 +108,11 @@ class DstHost(RuleOption):
self.type = 'dst6' self.type = 'dst6'
except socket.error: except socket.error:
try: try:
socket.inet_pton(socket.AF_INET, value) socket.inet_pton(socket.AF_INET, untrusted_value)
if value.count('.') != 3: if untrusted_value.count('.') != 3:
raise ValueError( raise ValueError(
'Invalid number of dots in IPv4 address') 'Invalid number of dots in IPv4 address')
value = untrusted_value
self.prefixlen = prefixlen or 32 self.prefixlen = prefixlen or 32
if self.prefixlen < 0 or self.prefixlen > 32: if self.prefixlen < 0 or self.prefixlen > 32:
raise ValueError( raise ValueError(
@ -110,28 +122,33 @@ class DstHost(RuleOption):
except socket.error: except socket.error:
self.type = 'dsthost' self.type = 'dsthost'
self.prefixlen = 0 self.prefixlen = 0
safe_set = string.ascii_lowercase + string.digits + '-._'
assert all(c in safe_set for c in untrusted_value)
value = untrusted_value
else: else:
host, prefixlen = value.split('/', 1) untrusted_host, untrusted_prefixlen = untrusted_value.split('/', 1)
prefixlen = int(prefixlen) prefixlen = int(untrusted_prefixlen)
if prefixlen < 0: if prefixlen < 0:
raise ValueError('netmask must be non-negative') raise ValueError('netmask must be non-negative')
self.prefixlen = prefixlen self.prefixlen = prefixlen
try: try:
socket.inet_pton(socket.AF_INET6, host) socket.inet_pton(socket.AF_INET6, untrusted_host)
value = untrusted_value
if prefixlen > 128: if prefixlen > 128:
raise ValueError('netmask for IPv6 must be <= 128') raise ValueError('netmask for IPv6 must be <= 128')
self.type = 'dst6' self.type = 'dst6'
except socket.error: except socket.error:
try: try:
socket.inet_pton(socket.AF_INET, host) socket.inet_pton(socket.AF_INET, untrusted_host)
if prefixlen > 32: if prefixlen > 32:
raise ValueError('netmask for IPv4 must be <= 32') raise ValueError('netmask for IPv4 must be <= 32')
self.type = 'dst4' self.type = 'dst4'
if host.count('.') != 3: if untrusted_host.count('.') != 3:
raise ValueError( raise ValueError(
'Invalid number of dots in IPv4 address') 'Invalid number of dots in IPv4 address')
value = untrusted_value
except socket.error: except socket.error:
raise ValueError('Invalid IP address: ' + host) raise ValueError('Invalid IP address: ' + untrusted_host)
super(DstHost, self).__init__(value) super(DstHost, self).__init__(value)
@ -141,15 +158,15 @@ class DstHost(RuleOption):
class DstPorts(RuleOption): class DstPorts(RuleOption):
def __init__(self, value): def __init__(self, untrusted_value):
if isinstance(value, int): if isinstance(untrusted_value, int):
value = str(value) untrusted_value = str(untrusted_value)
if value.count('-') == 1: if untrusted_value.count('-') == 1:
self.range = [int(x) for x in value.split('-', 1)] self.range = [int(x) for x in untrusted_value.split('-', 1)]
elif not value.count('-'): elif not untrusted_value.count('-'):
self.range = [int(value), int(value)] self.range = [int(untrusted_value), int(untrusted_value)]
else: else:
raise ValueError(value) raise ValueError(untrusted_value)
if any(port < 0 or port > 65536 for port in self.range): if any(port < 0 or port > 65536 for port in self.range):
raise ValueError('Ports out of range') raise ValueError('Ports out of range')
if self.range[0] > self.range[1]: if self.range[0] > self.range[1]:
@ -164,11 +181,11 @@ class DstPorts(RuleOption):
class IcmpType(RuleOption): class IcmpType(RuleOption):
def __init__(self, value): def __init__(self, untrusted_value):
super(IcmpType, self).__init__(value) untrusted_value = int(untrusted_value)
value = int(value) if untrusted_value < 0 or untrusted_value > 255:
if value < 0 or value > 255:
raise ValueError('ICMP type out of range') raise ValueError('ICMP type out of range')
super(IcmpType, self).__init__(untrusted_value)
@property @property
def rule(self): def rule(self):
@ -184,24 +201,42 @@ class SpecialTarget(RuleChoice):
class Expire(RuleOption): class Expire(RuleOption):
def __init__(self, value): def __init__(self, untrusted_value):
super(Expire, self).__init__(value) super(Expire, self).__init__(untrusted_value)
self.datetime = datetime.datetime.utcfromtimestamp(int(value)) self.datetime = datetime.datetime.utcfromtimestamp(int(untrusted_value))
@property @property
def rule(self): def rule(self):
return None return None
@property
def api_rule(self):
return 'expire=' + str(self)
@property @property
def expired(self): def expired(self):
return self.datetime < datetime.datetime.utcnow() return self.datetime < datetime.datetime.utcnow()
class Comment(RuleOption): class Comment(RuleOption):
# noinspection PyMissingConstructor
def __init__(self, untrusted_value):
# pylint: disable=super-init-not-called
# subset of string.punctuation
safe_set = string.ascii_letters + string.digits + \
':;,./-_[] '
assert all(x in safe_set for x in str(untrusted_value))
value = str(untrusted_value)
self._value = value
@property @property
def rule(self): def rule(self):
return None return None
@property
def api_rule(self):
return 'comment=' + str(self)
class Rule(qubes.PropertyHolder): class Rule(qubes.PropertyHolder):
def __init__(self, xml=None, **kwargs): def __init__(self, xml=None, **kwargs):
@ -311,6 +346,20 @@ class Rule(qubes.PropertyHolder):
values.append(value.rule) values.append(value.rule)
return ' '.join(values) return ' '.join(values)
@property
def api_rule(self):
values = []
# put comment at the end
for prop in sorted(self.property_list(),
key=(lambda p: p.__name__ == 'comment')):
value = getattr(self, prop.__name__)
if value is None:
continue
if value.api_rule is None:
continue
values.append(value.api_rule)
return ' '.join(values)
@classmethod @classmethod
def from_xml_v1(cls, node, action): def from_xml_v1(cls, node, action):
netmask = node.get('netmask') netmask = node.get('netmask')
@ -358,8 +407,43 @@ class Rule(qubes.PropertyHolder):
return cls(**kwargs) return cls(**kwargs)
@classmethod
def from_api_string(cls, untrusted_rule):
'''Parse a single line of firewall rule'''
# comment is allowed to have spaces
untrusted_options, _, untrusted_comment = untrusted_rule.partition(
'comment=')
# appropriate handlers in __init__ of individual options will perform
# option-specific validation
kwargs = {}
if untrusted_comment:
kwargs['comment'] = Comment(untrusted_value=untrusted_comment)
for untrusted_option in untrusted_options.strip().split(' '):
untrusted_key, untrusted_value = untrusted_option.split('=', 1)
if untrusted_key in kwargs:
raise ValueError('Option \'{}\' already set'.format(
untrusted_key))
if untrusted_key in [str(prop) for prop in cls.property_list()]:
kwargs[untrusted_key] = cls.property_get_def(
untrusted_key).type(untrusted_value=untrusted_value)
elif untrusted_key in ('dst4', 'dst6', 'dstname'):
if 'dsthost' in kwargs:
raise ValueError('Option \'{}\' already set'.format(
'dsthost'))
kwargs['dsthost'] = DstHost(untrusted_value=untrusted_value)
else:
raise ValueError('Unknown firewall option')
return cls(**kwargs)
def __eq__(self, other): def __eq__(self, other):
return self.rule == other.rule if isinstance(other, Rule):
return self.api_rule == other.api_rule
return self.api_rule == str(other)
def __hash__(self):
return hash(self.api_rule)
class Firewall(object): class Firewall(object):
@ -368,21 +452,23 @@ class Firewall(object):
self.vm = vm self.vm = vm
#: firewall rules #: firewall rules
self.rules = [] self.rules = []
#: default action
self.policy = None
if load: if load:
self.load() self.load()
@property
def policy(self):
''' Default action - always 'drop' '''
return Action('drop')
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Firewall): if isinstance(other, Firewall):
return self.policy == other.policy and self.rules == other.rules return self.rules == other.rules
return NotImplemented return NotImplemented
def load_defaults(self): def load_defaults(self):
'''Load default firewall settings''' '''Load default firewall settings'''
self.rules = [] self.rules = [Rule(None, action='accept')]
self.policy = Action('accept')
def clone(self, other): def clone(self, other):
'''Clone firewall settings from other instance. '''Clone firewall settings from other instance.
@ -390,7 +476,6 @@ class Firewall(object):
:param other: other :py:class:`Firewall` instance :param other: other :py:class:`Firewall` instance
''' '''
self.policy = other.policy
rules = [] rules = []
for rule in other.rules: for rule in other.rules:
new_rule = Rule() new_rule = Rule()
@ -421,10 +506,7 @@ class Firewall(object):
'''Load old (Qubes < 4.0) firewall XML format''' '''Load old (Qubes < 4.0) firewall XML format'''
policy_v1 = xml_root.get('policy') policy_v1 = xml_root.get('policy')
assert policy_v1 in ('allow', 'deny') assert policy_v1 in ('allow', 'deny')
if policy_v1 == 'allow': default_policy_is_accept = (policy_v1 == 'allow')
self.policy = Action('accept')
else:
self.policy = Action('drop')
def _translate_action(key): def _translate_action(key):
if xml_root.get(key, policy_v1) == 'allow': if xml_root.get(key, policy_v1) == 'allow':
@ -439,7 +521,7 @@ class Firewall(object):
action=_translate_action('icmp'), action=_translate_action('icmp'),
proto=Proto.icmp)) proto=Proto.icmp))
if self.policy == Action.accept: if default_policy_is_accept:
rule_action = Action.drop rule_action = Action.drop
else: else:
rule_action = Action.accept rule_action = Action.accept
@ -447,11 +529,11 @@ class Firewall(object):
for element in xml_root: for element in xml_root:
rule = Rule.from_xml_v1(element, rule_action) rule = Rule.from_xml_v1(element, rule_action)
self.rules.append(rule) self.rules.append(rule)
if default_policy_is_accept:
self.rules.append(Rule(None, action='accept'))
def load_v2(self, xml_root): def load_v2(self, xml_root):
'''Load new (Qubes >= 4.0) firewall XML format''' '''Load new (Qubes >= 4.0) firewall XML format'''
self.policy = Action(xml_root.findtext('policy'))
xml_rules = xml_root.find('rules') xml_rules = xml_root.find('rules')
for xml_rule in xml_rules: for xml_rule in xml_rules:
rule = Rule(xml_rule) rule = Rule(xml_rule)
@ -464,10 +546,6 @@ class Firewall(object):
xml_root = lxml.etree.Element('firewall', version=str(2)) xml_root = lxml.etree.Element('firewall', version=str(2))
xml_policy = lxml.etree.Element('policy')
xml_policy.text = str(self.policy)
xml_root.append(xml_policy)
xml_rules = lxml.etree.Element('rules') xml_rules = lxml.etree.Element('rules')
for rule in self.rules: for rule in self.rules:
if rule.expire: if rule.expire:
@ -499,7 +577,6 @@ class Firewall(object):
subprocess.call(["sudo", "systemctl", "start", subprocess.call(["sudo", "systemctl", "start",
"qubes-reload-firewall@%s.timer" % self.vm.name]) "qubes-reload-firewall@%s.timer" % self.vm.name])
def qdb_entries(self, addr_family=None): def qdb_entries(self, addr_family=None):
'''Return firewall settings serialized for QubesDB entries '''Return firewall settings serialized for QubesDB entries

View File

@ -151,7 +151,9 @@ class TestEmitter(qubes.events.Emitter):
effects = super(TestEmitter, self).fire_event(event, **kwargs) effects = super(TestEmitter, self).fire_event(event, **kwargs)
ev_kwargs = frozenset( ev_kwargs = frozenset(
(key, (key,
frozenset(value.items()) if isinstance(value, dict) else value) frozenset(value.items()) if isinstance(value, dict)
else tuple(value) if isinstance(value, list)
else value)
for key, value in kwargs.items() for key, value in kwargs.items()
) )
self.fired_events[(event, ev_kwargs)] += 1 self.fired_events[(event, ev_kwargs)] += 1
@ -161,7 +163,9 @@ class TestEmitter(qubes.events.Emitter):
effects = super(TestEmitter, self).fire_event_pre(event, **kwargs) effects = super(TestEmitter, self).fire_event_pre(event, **kwargs)
ev_kwargs = frozenset( ev_kwargs = frozenset(
(key, (key,
frozenset(value.items()) if isinstance(value, dict) else value) frozenset(value.items()) if isinstance(value, dict)
else tuple(value) if isinstance(value, list)
else value)
for key, value in kwargs.items() for key, value in kwargs.items()
) )
self.fired_events[(event, ev_kwargs)] += 1 self.fired_events[(event, ev_kwargs)] += 1

View File

@ -30,6 +30,7 @@ import libvirt
import qubes import qubes
import qubes.devices import qubes.devices
import qubes.firewall
import qubes.api.admin import qubes.api.admin
import qubes.tests import qubes.tests
import qubes.storage import qubes.storage
@ -1677,6 +1678,131 @@ class TC_00_VMs(AdminAPITestCase):
self.assertNotIn('+.some-tag', self.vm.tags) self.assertNotIn('+.some-tag', self.vm.tags)
self.assertFalse(self.app.save.called) self.assertFalse(self.app.save.called)
def test_570_firewall_get(self):
self.vm.firewall.save = unittest.mock.Mock()
value = self.call_mgmt_func(b'admin.vm.firewall.Get',
b'test-vm1', b'')
self.assertEqual(value, 'action=accept\n')
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_571_firewall_get_non_default(self):
self.vm.firewall.save = unittest.mock.Mock()
self.vm.firewall.rules = [
qubes.firewall.Rule(action='accept', proto='tcp',
dstports='1-1024'),
qubes.firewall.Rule(action='drop', proto='icmp',
comment='No ICMP'),
qubes.firewall.Rule(action='drop', proto='udp',
expire='1499450306'),
qubes.firewall.Rule(action='accept'),
]
value = self.call_mgmt_func(b'admin.vm.firewall.Get',
b'test-vm1', b'')
self.assertEqual(value,
'action=accept proto=tcp dstports=1-1024\n'
'action=drop proto=icmp comment=No ICMP\n'
'action=drop expire=1499450306 proto=udp\n'
'action=accept\n')
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_580_firewall_set_simple(self):
self.vm.firewall.save = unittest.mock.Mock()
value = self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', b'action=accept\n')
self.assertEqual(self.vm.firewall.rules,
['action=accept'])
self.assertTrue(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_581_firewall_set_multi(self):
self.vm.firewall.save = unittest.mock.Mock()
rules = [
qubes.firewall.Rule(action='accept', proto='tcp',
dstports='1-1024'),
qubes.firewall.Rule(action='drop', proto='icmp',
comment='No ICMP'),
qubes.firewall.Rule(action='drop', proto='udp',
expire='1499450306'),
qubes.firewall.Rule(action='accept'),
]
rules_txt = (
'action=accept proto=tcp dstports=1-1024\n'
'action=drop proto=icmp comment=No ICMP\n'
'action=drop expire=1499450306 proto=udp\n'
'action=accept\n')
value = self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules, rules)
self.assertTrue(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_582_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'action=accept protoxyz=tcp dst4=127.0.0.1\n'
'action=drop\n')
with self.assertRaises(ValueError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_583_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'proto=tcp dstports=1-1024\n'
'action=drop\n')
with self.assertRaises(AssertionError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_584_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'action=accept proto=tcp dstports=1-1024 '
'action=drop\n')
with self.assertRaises(ValueError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_585_firewall_set_invalid(self):
self.vm.firewall.save = unittest.mock.Mock()
rules_txt = (
'action=accept dstports=1-1024 comment=ążźł\n'
'action=drop\n')
with self.assertRaises(UnicodeDecodeError):
self.call_mgmt_func(b'admin.vm.firewall.Set',
b'test-vm1', b'', rules_txt.encode())
self.assertEqual(self.vm.firewall.rules,
[qubes.firewall.Rule(action='accept')])
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_590_firewall_reload(self):
self.vm.firewall.save = unittest.mock.Mock()
self.app.domains['test-vm1'].fire_event = self.emitter.fire_event
self.app.domains['test-vm1'].fire_event_pre = \
self.emitter.fire_event_pre
value = self.call_mgmt_func(b'admin.vm.firewall.Reload',
b'test-vm1', b'')
self.assertIsNone(value)
self.assertEventFired(self.emitter, 'firewall-changed')
self.assertFalse(self.vm.firewall.save.called)
self.assertFalse(self.app.save.called)
def test_990_vm_unexpected_payload(self): def test_990_vm_unexpected_payload(self):
methods_with_no_payload = [ methods_with_no_payload = [
b'admin.vm.List', b'admin.vm.List',
@ -1695,8 +1821,7 @@ class TC_00_VMs(AdminAPITestCase):
b'admin.vm.tag.Remove', b'admin.vm.tag.Remove',
b'admin.vm.tag.Set', b'admin.vm.tag.Set',
b'admin.vm.firewall.Get', b'admin.vm.firewall.Get',
b'admin.vm.firewall.RemoveRule', b'admin.vm.firewall.Reload',
b'admin.vm.firewall.Flush',
b'admin.vm.device.pci.Attach', b'admin.vm.device.pci.Attach',
b'admin.vm.device.pci.Detach', b'admin.vm.device.pci.Detach',
b'admin.vm.device.pci.List', b'admin.vm.device.pci.List',
@ -1748,8 +1873,9 @@ class TC_00_VMs(AdminAPITestCase):
b'admin.vm.property.List', b'admin.vm.property.List',
b'admin.vm.feature.List', b'admin.vm.feature.List',
b'admin.vm.tag.List', b'admin.vm.tag.List',
b'admin.vm.firewall.List', b'admin.vm.firewall.Get',
b'admin.vm.firewall.Flush', b'admin.vm.firewall.Set',
b'admin.vm.firewall.Reload',
b'admin.vm.microphone.Attach', b'admin.vm.microphone.Attach',
b'admin.vm.microphone.Detach', b'admin.vm.microphone.Detach',
b'admin.vm.microphone.Status', b'admin.vm.microphone.Status',
@ -1950,9 +2076,8 @@ class TC_00_VMs(AdminAPITestCase):
b'admin.vm.tag.Remove', b'admin.vm.tag.Remove',
b'admin.vm.tag.Set', b'admin.vm.tag.Set',
b'admin.vm.firewall.Get', b'admin.vm.firewall.Get',
b'admin.vm.firewall.RemoveRule', b'admin.vm.firewall.Set',
b'admin.vm.firewall.InsertRule', b'admin.vm.firewall.Reload',
b'admin.vm.firewall.Flush',
b'admin.vm.device.pci.Attach', b'admin.vm.device.pci.Attach',
b'admin.vm.device.pci.Detach', b'admin.vm.device.pci.Detach',
b'admin.vm.device.pci.List', b'admin.vm.device.pci.List',

View File

@ -80,6 +80,7 @@ class TC_01_Action(qubes.tests.QubesTestCase):
def test_001_rule(self): def test_001_rule(self):
instance = qubes.firewall.Action('accept') instance = qubes.firewall.Action('accept')
self.assertEqual(instance.rule, 'action=accept') self.assertEqual(instance.rule, 'action=accept')
self.assertEqual(instance.api_rule, 'action=accept')
# noinspection PyPep8Naming # noinspection PyPep8Naming
@ -93,6 +94,7 @@ class TC_02_Proto(qubes.tests.QubesTestCase):
def test_001_rule(self): def test_001_rule(self):
instance = qubes.firewall.Proto('tcp') instance = qubes.firewall.Proto('tcp')
self.assertEqual(instance.rule, 'proto=tcp') self.assertEqual(instance.rule, 'proto=tcp')
self.assertEqual(instance.api_rule, 'proto=tcp')
# noinspection PyPep8Naming # noinspection PyPep8Naming
@ -157,6 +159,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase):
self.assertEqual(instance.prefixlen, 128) self.assertEqual(instance.prefixlen, 128)
self.assertEqual(str(instance), '2001:abcd:efab::3/128') self.assertEqual(str(instance), '2001:abcd:efab::3/128')
self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::3/128') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::3/128')
self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::3/128')
def test_011_ipv6_prefixlen(self): def test_011_ipv6_prefixlen(self):
with self.assertNotRaises(ValueError): with self.assertNotRaises(ValueError):
@ -165,6 +168,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase):
self.assertEqual(instance.prefixlen, 64) self.assertEqual(instance.prefixlen, 64)
self.assertEqual(str(instance), '2001:abcd:efab::/64') self.assertEqual(str(instance), '2001:abcd:efab::/64')
self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64')
self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64')
def test_012_ipv6_parse_prefixlen(self): def test_012_ipv6_parse_prefixlen(self):
with self.assertNotRaises(ValueError): with self.assertNotRaises(ValueError):
@ -173,6 +177,7 @@ class TC_02_DstHost(qubes.tests.QubesTestCase):
self.assertEqual(instance.prefixlen, 64) self.assertEqual(instance.prefixlen, 64)
self.assertEqual(str(instance), '2001:abcd:efab::/64') self.assertEqual(str(instance), '2001:abcd:efab::/64')
self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64') self.assertEqual(instance.rule, 'dst6=2001:abcd:efab::/64')
self.assertEqual(instance.api_rule, 'dst6=2001:abcd:efab::/64')
def test_013_ipv6_invalid_prefix(self): def test_013_ipv6_invalid_prefix(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -211,6 +216,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase):
self.assertEqual(str(instance), '80') self.assertEqual(str(instance), '80')
self.assertEqual(instance.range, [80, 80]) self.assertEqual(instance.range, [80, 80])
self.assertEqual(instance.rule, 'dstports=80-80') self.assertEqual(instance.rule, 'dstports=80-80')
self.assertEqual(instance.api_rule, 'dstports=80-80')
def test_001_single_int(self): def test_001_single_int(self):
with self.assertNotRaises(ValueError): with self.assertNotRaises(ValueError):
@ -218,6 +224,7 @@ class TC_03_DstPorts(qubes.tests.QubesTestCase):
self.assertEqual(str(instance), '80') self.assertEqual(str(instance), '80')
self.assertEqual(instance.range, [80, 80]) self.assertEqual(instance.range, [80, 80])
self.assertEqual(instance.rule, 'dstports=80-80') self.assertEqual(instance.rule, 'dstports=80-80')
self.assertEqual(instance.api_rule, 'dstports=80-80')
def test_002_range(self): def test_002_range(self):
with self.assertNotRaises(ValueError): with self.assertNotRaises(ValueError):
@ -261,6 +268,7 @@ class TC_04_IcmpType(qubes.tests.QubesTestCase):
instance = qubes.firewall.IcmpType('8') instance = qubes.firewall.IcmpType('8')
self.assertEqual(str(instance), '8') self.assertEqual(str(instance), '8')
self.assertEqual(instance.rule, 'icmptype=8') self.assertEqual(instance.rule, 'icmptype=8')
self.assertEqual(instance.api_rule, 'icmptype=8')
def test_002_invalid(self): def test_002_invalid(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -283,6 +291,7 @@ class TC_05_SpecialTarget(qubes.tests.QubesTestCase):
def test_001_rule(self): def test_001_rule(self):
instance = qubes.firewall.SpecialTarget('dns') instance = qubes.firewall.SpecialTarget('dns')
self.assertEqual(instance.rule, 'specialtarget=dns') self.assertEqual(instance.rule, 'specialtarget=dns')
self.assertEqual(instance.api_rule, 'specialtarget=dns')
class TC_06_Expire(qubes.tests.QubesTestCase): class TC_06_Expire(qubes.tests.QubesTestCase):
@ -290,6 +299,7 @@ class TC_06_Expire(qubes.tests.QubesTestCase):
with self.assertNotRaises(ValueError): with self.assertNotRaises(ValueError):
instance = qubes.firewall.Expire(1463292452) instance = qubes.firewall.Expire(1463292452)
self.assertEqual(str(instance), '1463292452') self.assertEqual(str(instance), '1463292452')
self.assertEqual(instance.api_rule, 'expire=1463292452')
self.assertEqual(instance.datetime, self.assertEqual(instance.datetime,
datetime.datetime(2016, 5, 15, 6, 7, 32)) datetime.datetime(2016, 5, 15, 6, 7, 32))
self.assertIsNone(instance.rule) self.assertIsNone(instance.rule)
@ -322,6 +332,7 @@ class TC_07_Comment(qubes.tests.QubesTestCase):
with self.assertNotRaises(ValueError): with self.assertNotRaises(ValueError):
instance = qubes.firewall.Comment('Some comment') instance = qubes.firewall.Comment('Some comment')
self.assertEqual(str(instance), 'Some comment') self.assertEqual(str(instance), 'Some comment')
self.assertEqual(instance.api_rule, 'comment=Some comment')
self.assertIsNone(instance.rule) self.assertIsNone(instance.rule)
@ -445,6 +456,33 @@ class TC_08_Rule(qubes.tests.QubesTestCase):
self.assertIsNone(rule.proto) self.assertIsNone(rule.proto)
self.assertIsNone(rule.dstports) self.assertIsNone(rule.dstports)
def test_008_from_api_string(self):
rule_txt = 'action=drop proto=tcp dstports=80-80'
with self.assertNotRaises(ValueError):
rule = qubes.firewall.Rule.from_api_string(
rule_txt)
self.assertEqual(rule.dstports.range, [80, 80])
self.assertEqual(rule.proto, 'tcp')
self.assertEqual(rule.action, 'drop')
self.assertIsNone(rule.dsthost)
self.assertIsNone(rule.expire)
self.assertIsNone(rule.comment)
self.assertEqual(rule.api_rule, rule_txt)
def test_009_from_api_string(self):
rule_txt = 'action=accept expire=1463292452 proto=tcp ' \
'comment=Some comment, with spaces'
with self.assertNotRaises(ValueError):
rule = qubes.firewall.Rule.from_api_string(
rule_txt)
self.assertEqual(rule.comment, 'Some comment, with spaces')
self.assertEqual(rule.proto, 'tcp')
self.assertEqual(rule.action, 'accept')
self.assertEqual(rule.expire, '1463292452')
self.assertIsNone(rule.dstports)
self.assertIsNone(rule.dsthost)
self.assertEqual(rule.api_rule, rule_txt)
class TC_10_Firewall(qubes.tests.QubesTestCase): class TC_10_Firewall(qubes.tests.QubesTestCase):
def setUp(self): def setUp(self):
@ -463,17 +501,17 @@ class TC_10_Firewall(qubes.tests.QubesTestCase):
def test_000_defaults(self): def test_000_defaults(self):
fw = qubes.firewall.Firewall(self.vm, False) fw = qubes.firewall.Firewall(self.vm, False)
fw.load_defaults() fw.load_defaults()
self.assertEqual(fw.policy, 'accept') self.assertEqual(fw.policy, 'drop')
self.assertEqual(fw.rules, []) self.assertEqual(fw.rules, [qubes.firewall.Rule(None, action='accept')])
def test_001_save_load_empty(self): def test_001_save_load_empty(self):
fw = qubes.firewall.Firewall(self.vm, True) fw = qubes.firewall.Firewall(self.vm, True)
self.assertEqual(fw.policy, 'accept') self.assertEqual(fw.policy, 'drop')
self.assertEqual(fw.rules, []) self.assertEqual(fw.rules, [qubes.firewall.Rule(None, action='accept')])
fw.save() fw.save()
fw.load() fw.load()
self.assertEqual(fw.policy, 'accept') self.assertEqual(fw.policy, 'drop')
self.assertEqual(fw.rules, []) self.assertEqual(fw.rules, [qubes.firewall.Rule(None, action='accept')])
def test_002_save_load_rules(self): def test_002_save_load_rules(self):
fw = qubes.firewall.Firewall(self.vm, True) fw = qubes.firewall.Firewall(self.vm, True)
@ -485,13 +523,13 @@ class TC_10_Firewall(qubes.tests.QubesTestCase):
qubes.firewall.Rule(None, action='accept', specialtarget='dns'), qubes.firewall.Rule(None, action='accept', specialtarget='dns'),
] ]
fw.rules.extend(rules) fw.rules.extend(rules)
fw.policy = qubes.firewall.Action.drop
fw.save() fw.save()
self.assertTrue(os.path.exists(os.path.join( self.assertTrue(os.path.exists(os.path.join(
self.vm.dir_path, self.vm.firewall_conf))) self.vm.dir_path, self.vm.firewall_conf)))
fw = qubes.firewall.Firewall(TestVM(), True) fw = qubes.firewall.Firewall(TestVM(), True)
self.assertEqual(fw.policy, qubes.firewall.Action.drop) self.assertEqual(fw.policy, qubes.firewall.Action.drop)
self.assertEqual(fw.rules, rules) self.assertEqual(fw.rules,
[qubes.firewall.Rule(None, action='accept')] + rules)
def test_003_load_v1(self): def test_003_load_v1(self):
xml_txt = """<QubesFirewallRules dns="allow" icmp="allow" xml_txt = """<QubesFirewallRules dns="allow" icmp="allow"
@ -524,8 +562,7 @@ class TC_10_Firewall(qubes.tests.QubesTestCase):
dstports=67, expire=1373300257), dstports=67, expire=1373300257),
qubes.firewall.Rule(None, action='accept', specialtarget='dns'), qubes.firewall.Rule(None, action='accept', specialtarget='dns'),
] ]
fw.rules.extend(rules) fw.rules = rules
fw.policy = qubes.firewall.Action.drop
fw.save() fw.save()
rules.pop(2) rules.pop(2)
fw = qubes.firewall.Firewall(self.vm, True) fw = qubes.firewall.Firewall(self.vm, True)
@ -539,8 +576,7 @@ class TC_10_Firewall(qubes.tests.QubesTestCase):
qubes.firewall.Rule(None, action='accept', proto='udp'), qubes.firewall.Rule(None, action='accept', proto='udp'),
qubes.firewall.Rule(None, action='accept', specialtarget='dns'), qubes.firewall.Rule(None, action='accept', specialtarget='dns'),
] ]
fw.rules.extend(rules) fw.rules = rules
fw.policy = qubes.firewall.Action.drop
expected_qdb_entries = { expected_qdb_entries = {
'policy': 'drop', 'policy': 'drop',
'0000': 'action=drop proto=icmp', '0000': 'action=drop proto=icmp',