#!/usr/bin/python2
# -*- encoding: utf8 -*-
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2012  Marek Marczykowski <marmarek@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
#
import datetime

from qubes.qubes import QubesVmCollection
from optparse import OptionParser;
import subprocess
import sys
import re
import os
import socket

def parse_rule(args):
    """Parse a firewall rule given as a list of command-line words.

    Accepted forms of ``args`` (a list of strings):

      * ``[address[/netmask], proto, port]``
      * ``[address[/netmask], "proto/port"]``
      * ``[address[/netmask], "any"]``

    ``address`` is an IPv4 address or a host name (one trailing dot is
    stripped); a netmask (0-32) is accepted only together with an IPv4
    address.  ``proto`` must be ``tcp``, ``udp`` or ``any``.  ``port`` is
    a port number, a service name known to socket.getservbyname(), or a
    ``begin-end`` range whose end must be numeric.

    Returns a dict with the keys ``address``, ``netmask`` (int),
    ``proto``, ``portBegin`` and ``portEnd`` (ints or None).  On invalid
    input an ``ERROR: ...`` line is written to stderr and None is
    returned.
    """
    def fail(message):
        # write() instead of the print statement, so the function behaves
        # the same under Python 2 and Python 3.
        sys.stderr.write("ERROR: %s\n" % message)
        return None

    def lookup_port(text):
        # Port number or service name -> int; None when it is neither.
        if text.isdigit():
            return int(text)
        try:
            return socket.getservbyname(text)
        except socket.error:
            return None

    if len(args) < 2:
        return fail("Rule must have at least address and protocol")

    address = args[0]
    netmask = 32
    proto = args[1]
    port = args[2] if len(args) > 2 else None
    port_end = None

    if "/" in address:
        address, mask_text = address.split("/", 1)
        if not mask_text.isdigit():
            return fail("Invalid netmask")
        if re.match(r"^([0-9]{1,3}\.){3}[0-9]{1,3}$", address) is None:
            return fail("Only IP is allowed when specyfying netmask")
        netmask = int(mask_text)
        if netmask > 32:
            return fail("Invalid netmask")

    if address[-1:] == ".":
        address = address[:-1]

    labels = address.split(".")
    allowed = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
    if not all(allowed.match(label) for label in labels):
        return fail("Invalid hostname")

    # Four all-numeric labels can only be an IPv4 address, so each octet
    # has to fit in a byte.
    if len(labels) == 4 and all(label.isdigit() for label in labels):
        if any(int(label) > 255 for label in labels):
            return fail("Invalid IP address")

    if "/" in proto:
        # "proto/port" form; the port given here wins over args[2].
        proto, port = proto.split("/", 1)

    if proto not in ('tcp', 'udp', 'any'):
        return fail("Protocol must be one of: 'tcp', 'udp', 'any'")

    if proto != "any" and port is None:
        return fail("Port required for protocol %s" % args[1])

    if port is not None:
        begin_text, end_text = port, None
        # Try the whole word first so that service names containing a
        # hyphen (e.g. "http-alt") are not mistaken for a port range.
        port = lookup_port(begin_text)
        if port is None and "-" in begin_text:
            begin_text, end_text = begin_text.split("-", 1)
            port = lookup_port(begin_text)
        if port is None:
            return fail("Invalid port/service name '%s'" % begin_text)

        if end_text is not None:
            if not end_text.isdigit():
                return fail("Invalid port '%s'" % end_text)
            port_end = int(end_text)

        if port > 65535:
            return fail("Invalid port '%s'" % port)
        if port_end is not None:
            if port_end > 65535:
                return fail("Invalid port '%s'" % port_end)
            if port_end < port:
                return fail("Invalid port range '%s-%s'" % (port, port_end))

    return {
        'address': address,
        'netmask': netmask,
        'proto': proto,
        'portBegin': port,
        'portEnd': port_end,
    }

def list_rules(rules, numeric=False):
    """Print the given firewall rules as a numbered text table on stdout.

    rules   -- iterable of rule dicts with the keys 'address', 'netmask',
               'proto', 'portBegin', 'portEnd' and, optionally, 'expire'
               (a timestamp passed to datetime.datetime.fromtimestamp()).
    numeric -- when False, a single tcp/udp port is shown as a service
               name if socket.getservbyport() knows one; port ranges are
               always shown as numbers.

    Rules carrying 'expire' get an "<-- expires at ..." note after their
    table row.
    """
    fields = ["num", "address", "proto", "port(s)"]

    rows = []
    for number, rule in enumerate(rules, 1):
        address = rule['address']
        if rule['netmask'] < 32:
            address += '/' + str(rule['netmask'])

        ports = ''
        if rule['proto'] in ('tcp', 'udp'):
            ports = str(rule['portBegin'])
            if rule['portEnd'] is not None:
                ports += '-' + str(rule['portEnd'])
            elif not numeric and rule['portBegin'] is not None:
                try:
                    ports = str(socket.getservbyport(rule['portBegin']))
                except (socket.error, OverflowError):
                    # socket.error: no service is registered for the port.
                    # OverflowError: the stored port is outside 0-65535.
                    # Either way the plain number stays in place.
                    pass

        row = {
            'num': "{0:>2}".format(number),
            'address': address,
            'proto': rule['proto'],
            'port(s)': ports,
        }
        if 'expire' in rule:
            row['expire'] = str(
                datetime.datetime.fromtimestamp(rule['expire']))
        rows.append(row)

    # Each column is as wide as its longest cell, header included.
    widths = dict(
        (f, max([len(f)] + [len(row[f]) for row in rows])) for f in fields)

    # One padding space on each side of a cell, then the column border.
    separator = "".join("-" * (widths[f] + 2) + "+" for f in fields)
    header = "".join(" {0:^{1}} |".format(f, widths[f]) for f in fields)

    lines = [separator, header, separator]
    for row in rows:
        line = "".join(
            " {0:<{1}} |".format(row[f], widths[f]) for f in fields)
        if 'expire' in row:
            line += " <-- expires at %s" % row['expire']
        lines.append(line)

    for line in lines:
        sys.stdout.write(line + "\n")

def display_firewall(conf, numeric=False):
    """Print the firewall configuration ``conf`` on stdout.

    The policy line and the ICMP, DNS and yum-proxy switches are read from
    conf['allow'], conf['allowIcmp'], conf['allowDns'] and
    conf['allowYumProxy']; conf['rules'] is then handed to list_rules()
    together with ``numeric``.
    """
    if conf['allow']:
        policy = "ALLOW all traffic except"
    else:
        policy = "DENY all traffic except"
    sys.stdout.write("Firewall policy: %s" % policy + "\n")

    switches = (
        ("ICMP: %s", 'allowIcmp'),
        ("DNS: %s", 'allowDns'),
        ("Qubes yum proxy: %s", 'allowYumProxy'),
    )
    # Each line is emitted before the next key is looked up.
    for template, key in switches:
        state = "ALLOW" if conf[key] else 'DENY'
        sys.stdout.write(template % state + "\n")

    list_rules(conf['rules'], numeric)

def add_rule(conf, args):
    """Parse ``args`` with parse_rule() and append the rule to conf['rules'].

    Returns True when a rule was appended and False when parse_rule()
    returned None, in which case ``conf`` is left untouched.
    """
    new_rule = parse_rule(args)
    if new_rule is not None:
        conf['rules'].append(new_rule)
        return True
    return False

def del_rule(conf, args):
    """Remove one rule from conf['rules'].

    ``args`` is either a single all-digit word, taken as the 1-based rule
    number, or a rule specification that parse_rule() understands; in the
    latter case the first rule equal to the parsed one is removed.

    Returns True when a rule was removed.  Otherwise an error is reported
    on stderr (by this function or by parse_rule()) and False is returned.
    """
    selected_by_number = len(args) == 1 and args[0].isdigit()

    if selected_by_number:
        position = int(args[0])
        if not (1 <= position <= len(conf['rules'])):
            sys.stderr.write("ERROR: Rule number out of range" + "\n")
            return False
        conf['rules'].pop(position - 1)
        return True

    wanted = parse_rule(args)
    if wanted is None:
        return False

    try:
        conf['rules'].remove(wanted)
    except ValueError:
        sys.stderr.write("ERROR: Rule not found" + "\n")
        return False
    return True

def allow_deny_value(s):
    """Translate the word "allow" or "deny" into True or False.

    Any other value writes an error line to stderr and terminates the
    program by raising SystemExit with status 1.
    """
    if s == "allow":
        return True
    if s == "deny":
        return False

    sys.stderr.write('ERROR: Only "allow" or "deny" allowed\n')
    # sys.exit() rather than the bare exit(): the latter is a helper that
    # the site module installs for interactive use and is not guaranteed
    # to exist (for example under "python -S").
    sys.exit(1)

def main():
    """Command-line entry point: show or modify the firewall of one VM.

    The first positional argument is the VM name; the remaining ones form
    the rule specification used by --add / --del.  Without any modifying
    option the current settings are listed.  Modifying options take the
    VM database write lock, store the new configuration and, if the VM is
    running behind a proxy VM, ask that proxy VM to publish the rules.

    Exits with status 1 (SystemExit) when run as root without
    --force-root, when the VM does not exist, when an allow/deny value is
    invalid, or when the rule given to --add / --del is rejected.
    """
    usage =  "usage: %prog [-n] <vm-name> [action] [rule spec]\n"
    usage += "       rule specification can be one of:\n"
    usage += "         address|hostname[/netmask] tcp|udp port[-port]\n"
    usage += "         address|hostname[/netmask] tcp|udp service_name\n"
    usage += "         address|hostname[/netmask] any\n"
    parser = OptionParser(usage)
    parser.add_option("-l", "--list", dest="do_list", action="store_true", default=True,
            help="List firewall settings (default action)")
    parser.add_option("-a", "--add", dest="do_add", action="store_true", default=False,
            help="Add rule")
    parser.add_option("-d", "--del", dest="do_del", action="store_true", default=False,
            help="Remove rule (given by number or by rule spec)")
    parser.add_option("-P", "--policy", dest="set_policy", action="store", default=None,
            help="Set firewall policy (allow/deny)")
    parser.add_option("-i", "--icmp", dest="set_icmp", action="store", default=None,
            help="Set ICMP access (allow/deny)")
    parser.add_option("-D", "--dns", dest="set_dns", action="store", default=None,
            help="Set DNS access (allow/deny)")
    parser.add_option("-Y", "--yum-proxy", dest="set_yum_proxy", action="store", default=None,
            help="Set access to Qubes yum proxy (allow/deny)")
    parser.add_option("-r", "--reload", dest="reload", action="store_true",
                      default=False, help="Reload firewall (implied by any "
                                          "change action)")

    parser.add_option("-n", "--numeric", dest="numeric", action="store_true", default=False,
            help="Display port numbers instead of services (makes sense only with --list)")
    parser.add_option("--force-root", action="store_true", dest="force_root", default=False,
                      help="Force to run, even with root privileges")

    (options, args) = parser.parse_args()
    if len(args) < 1:
        parser.error("You must specify VM name!")
    vmname = args[0]
    args = args[1:]

    if hasattr(os, "geteuid") and os.geteuid() == 0:
        if not options.force_root:
            sys.stderr.write("*** Running this tool as root is strongly discouraged, this will lead you in permissions problems.\n")
            sys.stderr.write("Retry as unprivileged user.\n")
            sys.stderr.write("... or use --force-root to continue anyway.\n")
            sys.exit(1)

    # Any modifying option turns the default "list" action off.
    if options.do_add or options.do_del or options.set_policy or \
            options.set_icmp or options.set_dns or options.set_yum_proxy:
        options.do_list = False

    qvm_collection = QubesVmCollection()
    if options.do_list:
        qvm_collection.lock_db_for_reading()
    else:
        qvm_collection.lock_db_for_writing()
    db_locked = True

    try:
        qvm_collection.load()
        if options.do_list:
            # Nothing will be saved, so the lock is only needed for load().
            qvm_collection.unlock_db()
            db_locked = False

        vm = qvm_collection.get_vm_by_name(vmname)
        if vm is None:
            sys.stderr.write("A VM with the name '{0}' does not exist in the system.".format(vmname) + "\n")
            sys.exit(1)

        changed = False
        conf = vm.get_firewall_conf()

        switches = (
            (options.set_policy, 'allow'),
            (options.set_icmp, 'allowIcmp'),
            (options.set_dns, 'allowDns'),
            (options.set_yum_proxy, 'allowYumProxy'),
        )
        for word, key in switches:
            if word:
                conf[key] = allow_deny_value(word)
                changed = True

        if options.do_add:
            # A rejected rule aborts with a failure status and nothing is
            # saved, instead of quietly dropping the -P/-i/-D/-Y changes
            # and reporting success.  add_rule() has already printed why.
            if not add_rule(conf, args):
                sys.exit(1)
            changed = True
        elif options.do_del:
            if not del_rule(conf, args):
                sys.exit(1)
            changed = True
        elif options.do_list and not options.reload:
            if not vm.has_firewall():
                sys.stdout.write("INFO: This VM has no firewall rules set, below defaults are listed\n")
            display_firewall(conf, options.numeric)

        if changed:
            vm.write_firewall_conf(conf)
            qvm_collection.save()
        if changed or options.reload:
            if vm.is_running():
                if vm.netvm is not None and vm.netvm.is_proxyvm():
                    vm.netvm.write_iptables_qubesdb_entry()
    finally:
        # Also reached on sys.exit() and on exceptions, so a write lock
        # is never left behind by an early exit.
        if db_locked:
            qvm_collection.unlock_db()


# Run the tool only when executed as a script, so that importing this file
# (for example to reuse parse_rule) does not parse sys.argv or lock the
# VM database.
if __name__ == "__main__":
    main()