#!/usr/bin/python2.6
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2010  Rafal Wojtczuk  <rafal@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
#

import xen.lowlevel.xs
import os
import sys
import subprocess
import daemon
import time
from qubes.qubes import QubesVmCollection
from qubes.qubes import QubesException
from qubes.qubes import QubesDaemonPidfile

# Last sequence number handed out by get_next_filename_seq(); the counter is
# incremented before use, so the first value returned is '51'.
# NOTE(review): the reason for starting at 50 is not evident from this file.
filename_seq = 50
# Helper executable invoked with the 'new', 'umount' and 'send' subcommands
# to perform the actual pendrive operations.
pen_cmd = '/usr/lib/qubes/qubes_pencmd'

def get_next_filename_seq():
    """Advance the module-wide sequence counter and return it as a string."""
    global filename_seq
    next_seq = filename_seq + 1
    filename_seq = next_seq
    return str(next_seq)

def logproc(msg):
    """Append `msg` followed by a newline to the daemon log file.

    The file is opened and closed on every call; the `with` block makes
    sure the handle is released even if the write itself fails.
    """
    with open('/var/log/qubes/qfileexchgd', 'a') as f:
        f.write(msg + '\n')

def get_req_node(domain_id):
    """Return the xenstore path of the 'device/qpen' node of `domain_id`."""
    return ''.join(['/local/domain/', domain_id, '/device/qpen'])

def get_name_node(domain_id):
    """Return the xenstore path of the 'name' node of `domain_id`."""
    return ''.join(['/local/domain/', domain_id, '/name'])

def only_in_first_list(l1, l2):
    """Return the items of `l1` that do not occur in `l2`.

    Order and duplicates of `l1` are preserved; a new list is returned.
    """
    return [item for item in l1 if item not in l2]


class WatchType:
    """Watch token pairing a handler (`fn`) with the argument (`param`)
    that the watch loop passes to it when the watch fires."""

    def __init__(self, fn, param):
        self.fn, self.param = fn, param

class DomainState:
    """Pendrive-exchange state of a single domain.

    Two independent state variables are tracked:

    * send_state: 'idle' or 'has_clean_pendrive' (set after a successful
      'new' request, reset to 'idle' after a successful transfer).
    * rcv_state: 'idle', 'has_loaded_pendrive' (a pendrive was sent to
      this domain) or 'waits_to_umount' (another domain wants to send
      and the user agreed to unmount the current pendrive first).
    """

    def __init__(self, domain, dict):
        """Create the state object for domain id `domain`.

        `dict` is the shared mapping of domain id -> DomainState that is
        used to look up the target of a transfer.
        """
        self.rcv_state = 'idle'
        self.send_state = 'idle'
        self.domain_id = domain
        self.domdict = dict
        self.send_seq = None
        self.rcv_seq = None
        self.waiting_sender = None
        # Human-readable VM name, assigned from outside once it is known.
        # Initialised here so dialog/log code never hits a missing attribute.
        self.name = ''

    def handle_request(self, request):
        """Serve one request string written by the domain.

        Recognised requests are 'new', 'send <vmname>' and 'umount'; each
        is only served when the domain is in the matching state, otherwise
        the request is logged as not served.  `None`, empty and
        whitespace-only requests are ignored.
        """
        if request is None:
            return
        words = request.split()
        if not words:
            # The value is written by the (untrusted) domain; an empty
            # string must not crash the daemon with an IndexError.
            return
        rq = words[0]
        arg = words[1] if len(words) > 1 else None

        if rq == 'new' and self.send_state == 'idle':
            self.send_seq = get_next_filename_seq()
            retcode = subprocess.call([pen_cmd, 'new', self.domain_id, self.send_seq])
            logproc( 'Give domain ' + self.domain_id + ' a clean pendrive, retcode= ' + str(retcode))
            if retcode == 0:
                self.send_state = 'has_clean_pendrive'
        elif rq == 'send' and self.send_state == 'has_clean_pendrive' and arg is not None:
            logproc( 'send from ' + self.domain_id + ' to ' + arg)
            if self.handle_transfer(arg):
                self.send_state = 'idle'
        elif rq == 'umount' and self.rcv_state == 'has_loaded_pendrive':
            retcode = subprocess.call([pen_cmd, 'umount', self.domain_id, self.rcv_seq])
            if retcode == 0:
                self.rcv_state = 'idle'
                self.rcv_seq = None
            logproc( 'set state of ' + self.domain_id + ' loaded->idle retcode=' + str(retcode))
        elif rq == 'umount' and self.rcv_state == 'waits_to_umount':
            retcode = subprocess.call([pen_cmd, 'umount', self.domain_id, self.rcv_seq])
            if retcode == 0:
                self.rcv_state = 'idle'
                self.rcv_seq = None
                sender = self.waiting_sender
                self.waiting_sender = None
                # Retry the transfer that was postponed until this umount,
                # provided the sender still holds its clean pendrive.
                if sender is not None and sender.send_state == 'has_clean_pendrive':
                    if sender.handle_transfer(self.name):
                        sender.send_state = 'idle'
        else:
            logproc( 'request ' + request + ' not served due to nonmatching state')

    def ask_to_umount(self, vmname):
        """Ask the user (via kdialog) whether `vmname` should unmount its
        current pendrive so the transfer can be retried.

        Returns True when the dialog exits with code 0 (Yes), else False.
        """
        q = 'VM ' + vmname + ' has already an incoming pendrive, and thus '
        q+= 'cannot accept another one. If you intend to unmount its current '
        q+= 'pendrive and retry this transfer, press Yes. '
        q+= 'Otherwise press No to fail this transfer.'
        retcode = subprocess.call(['/usr/bin/kdialog', '--yesno', q, '--title', 'Some additional action required'])
        return retcode == 0

    def handle_transfer(self, vmname):
        """Try to hand this domain's pendrive over to the VM named `vmname`.

        Returns True when the transfer was authorised by the user and the
        'send' command was issued; returns False when the target does not
        exist, is not running, is not tracked, is busy, or the user
        declined.  When the target is busy and the user agrees, the target
        is put into 'waits_to_umount' with this domain as waiting sender.
        """
        qvm_collection = QubesVmCollection()
        qvm_collection.lock_db_for_reading()
        qvm_collection.load()
        qvm_collection.unlock_db()

        vm = qvm_collection.get_vm_by_name(vmname)
        if vm is None:
            logproc( 'Domain ' + vmname + ' does not exist ?')
            return False
        if not vm.is_running():
            logproc( 'Domain ' + vmname + ' is not running ?')
            return False
        target = self.domdict.get(str(vm.get_xid()))
        if target is None:
            # Running VM whose domain id is not (yet) in the shared dict;
            # fail the transfer instead of raising KeyError.
            logproc( 'Domain ' + vmname + ' is not tracked ?')
            return False
        if target.rcv_state != 'idle':
            if self.ask_to_umount(vmname):
                target.rcv_state = 'waits_to_umount'
                target.waiting_sender = self
            logproc( 'target domain ' + target.domain_id + ' is not idle, now ' + target.rcv_state)
            return False
        retcode = subprocess.call(['/usr/bin/kdialog', '--yesno', 'Do you authorize pendrive transfer from ' + self.name + ' to ' + vmname + '?' , '--title', 'Security confirmation'])
        logproc('handle_transfer: kdialog retcode=' + str(retcode))
        if retcode != 0:
            return False
        target.rcv_state = 'has_loaded_pendrive'
        # NOTE(review): the result of the 'send' command is only logged; the
        # states are updated and True is returned even when it fails.
        retcode = subprocess.call([pen_cmd, 'send', self.domain_id, target.domain_id, self.send_seq])
        target.rcv_seq = self.send_seq
        self.send_seq = None
        logproc( 'set state of ' + target.domain_id + ' to has_loaded_pendrive, retcode=' + str(retcode))
        return True


class XS_Watcher:
    """Xenstore watcher that tracks domains and dispatches their requests.

    One watch on '/local/domain' keeps `domdict` (domain id -> DomainState)
    in sync with the domain list; every tracked domain additionally gets a
    watch on its request node and on its name node.
    """

    def __init__(self):
        """Open a xenstore handle and start watching the domain list."""
        self.domdict = {}
        self.handle = xen.lowlevel.xs.xs()
        self.handle.watch('/local/domain', WatchType(XS_Watcher.dom_list_change, None))

    def dom_list_change(self, param):
        """Synchronise `domdict` with the current list of domains.

        Newly appeared domains get a DomainState plus request/name watches;
        domains that disappeared have their watches removed and are dropped.
        `param` is unused (the watch token carries None).
        """
        curr = self.handle.ls('', '/local/domain')
        if curr is None:
            return
        for i in only_in_first_list(curr, self.domdict.keys()):
            newdom = DomainState(i, self.domdict)
            # The name is unknown until the name watch delivers it.
            newdom.name = ''
            newdom.watch_token = WatchType(XS_Watcher.request, newdom)
            newdom.watch_name = WatchType(XS_Watcher.namechange, newdom)
            self.domdict[i] = newdom
            self.handle.watch(get_req_node(i), newdom.watch_token)
            self.handle.watch(get_name_node(i), newdom.watch_name)
            logproc( 'added domain ' + i)
        for i in only_in_first_list(self.domdict.keys(), curr):
            gone = self.domdict[i]
            self.handle.unwatch(get_req_node(i), gone.watch_token)
            self.handle.unwatch(get_name_node(i), gone.watch_name)
            del self.domdict[i]
            logproc( 'removed domain ' + i)

    def request(self, domain_param):
        """Read the domain's request node and let its DomainState serve it."""
        ret = self.handle.read('', get_req_node(domain_param.domain_id))
        domain_param.handle_request(ret)

    def namechange(self, domain_param):
        """Record the domain's name once its name node holds a non-empty value."""
        ret = self.handle.read('', get_name_node(domain_param.domain_id))
        if ret:
            domain_param.name = ret
            logproc( 'Name for domain xid ' + domain_param.domain_id + ' is ' + ret )

    def watch_loop(self):
        """Redirect stderr to the error log and dispatch watch events forever.

        Each event's token is a WatchType; its handler is called with this
        watcher and the token's parameter.
        """
        sys.stderr = open('/var/log/qubes/qfileexchgd.errors', 'a')
        while True:
            result = self.handle.read_watch()
            token = result[1]
            token.fn(self, token.param)

def main():
    """Start the qfileexchgd daemon unless another instance is running.

    A stale pidfile is removed (with a notice on stdout); if the pidfile
    belongs to a live instance the process exits with status 0.  Otherwise
    the process daemonizes and runs the xenstore watch loop.
    """
    lock = QubesDaemonPidfile("qfileexchgd")
    if lock.pidfile_exists():
        if lock.pidfile_is_stale():
            lock.remove_pidfile()
            # Single-argument parenthesised form prints identically on
            # Python 2 and also parses on Python 3.
            print("Removed stale pidfile (has the previous daemon instance crashed?).")
        else:
            # sys.exit rather than the site-provided exit() helper, which
            # is meant for interactive use and may be absent.
            sys.exit(0)

    context = daemon.DaemonContext(
        working_directory="/var/run/qubes",
        pidfile=lock)
    with context:
        XS_Watcher().watch_loop()

# Start the daemon. There is no __main__ guard, so this call runs whenever
# the file is executed or imported.
main()