#!/usr/bin/python2
# coding=utf-8
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2010      Rafal Wojtczuk  <rafal@invisiblethingslab.com>
# Copyright (C) 2013-2015 Marek Marczykowski-Górecki
#                                         <marmarek@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
#
import os
import subprocess
import sys
import shutil
import time

from qubes.qubes import QubesVmCollection, QubesException
from qubes.qubes import QubesDispVmLabels
from qubes.notify import tray_notify, tray_notify_error, tray_notify_init


# Path checked by QfileDaemonDvm.dvm_setup_ok(): it has to be an existing file
# whose mtime is not older than that of dvmdata/savefile-root.
current_savefile = '/var/run/qubes/current-savefile'
# Read with os.readlink() (so it is expected to be a symlink): the last path
# component of its target is used as the DispVM template name, and
# saved-cows.tar found under this path is unpacked into it with bsdtar.
current_savefile_vmdir = '/var/lib/qubes/dvmdata/vmdir'


class QfileDaemonDvm:
    """Create, start and unregister a DispVM requested by another domain.

    ``name`` is the name of the calling domain; it is looked up in the
    Qubes VM collection when the DispVM is created.
    """

    def __init__(self, name):
        """Remember the name of the domain that requested the DispVM."""
        self.name = name

    @staticmethod
    def _log_time(stage):
        """Write a ``time=<epoch>, <stage>`` progress line to stderr."""
        sys.stderr.write("time=%s, %s\n" % (str(time.time()), stage))

    @staticmethod
    def get_disp_templ():
        """Return the last path component of the vmdir symlink target.

        The result is used by do_get_dvm() as the DispVM template name.
        """
        vmdir = os.readlink(current_savefile_vmdir)
        return vmdir.split('/')[-1]

    def do_get_dvm(self):
        """Create and start a new DispVM for the calling domain.

        Optional overrides are read directly from the command line:
        ``sys.argv[4]`` (label name) and ``sys.argv[5]`` (firewall
        configuration path); empty strings mean "no override".

        Returns the started DispVM object, or None when the calling
        domain or the DispVM template is not in the collection, or when
        unpacking saved-cows.tar failed.

        Raises AssertionError for an unknown label or a non-existent
        firewall path, and re-raises MemoryError/QubesException from
        ``dispvm.start()`` after showing a tray error notification.
        """
        tray_notify("Starting new DispVM...", "red")

        qvm_collection = QubesVmCollection()
        qvm_collection.lock_db_for_writing()
        try:
            # Unpack the saved COW images in the background while the
            # collection is loaded and the VM object is prepared; the result
            # is collected with wait() right before the VM is started.
            tar_process = subprocess.Popen(
                ['bsdtar', '-C', current_savefile_vmdir,
                 '-xSUf', os.path.join(current_savefile_vmdir,
                                       'saved-cows.tar')])

            qvm_collection.load()
            self._log_time("collection loaded")

            vm = qvm_collection.get_vm_by_name(self.name)
            if vm is None:
                sys.stderr.write(
                    'Domain ' + self.name + ' does not exist ?\n')
                return None

            # Label: inherited from the calling VM unless overridden.
            label = vm.label
            if len(sys.argv) > 4 and len(sys.argv[4]) > 0:
                # Raised explicitly instead of using ``assert`` so that the
                # validation is not stripped under ``python -O``; the
                # exception type is the same as before.
                if sys.argv[4] not in QubesDispVmLabels.keys():
                    raise AssertionError("Invalid label")
                label = QubesDispVmLabels[sys.argv[4]]

            disp_templ = self.get_disp_templ()
            vm_disptempl = qvm_collection.get_vm_by_name(disp_templ)
            if vm_disptempl is None:
                sys.stderr.write(
                    'Domain ' + disp_templ + ' does not exist ?\n')
                return None

            dispvm = qvm_collection.add_new_vm('QubesDisposableVm',
                                               disp_template=vm_disptempl,
                                               label=label)
            self._log_time("VM created")

            # By default inherit firewall rules from the calling VM.
            disp_firewall_conf = '/var/run/qubes/%s-firewall.xml' % dispvm.name
            dispvm.firewall_conf = disp_firewall_conf
            if os.path.exists(vm.firewall_conf):
                shutil.copy(vm.firewall_conf, disp_firewall_conf)
            elif vm.qid == 0 and os.path.exists(vm_disptempl.firewall_conf):
                # For a DispVM called from dom0, use the rules of the DispVM
                # template instead.
                shutil.copy(vm_disptempl.firewall_conf, disp_firewall_conf)
            if len(sys.argv) > 5 and len(sys.argv[5]) > 0:
                # An explicit firewall file replaces the inherited copy.
                if not os.path.exists(sys.argv[5]):
                    raise AssertionError("Invalid firewall.conf location")
                dispvm.firewall_conf = sys.argv[5]

            if vm.qid != 0:
                dispvm.uses_default_netvm = False
                # netvm can be changed before restore,
                # but cannot be enabled/disabled
                if (dispvm.netvm is None) == (vm.dispvm_netvm is None):
                    dispvm.netvm = vm.dispvm_netvm

            # Wait for tar to finish
            if tar_process.wait() != 0:
                sys.stderr.write('Failed to unpack saved-cows.tar\n')
                return None

            self._log_time("VM starting")
            try:
                dispvm.start()
            except (MemoryError, QubesException) as e:
                tray_notify_error(str(e))
                raise

            if vm.qid != 0:
                # if need to enable/disable netvm, do it while DispVM is alive
                if (dispvm.netvm is None) != (vm.dispvm_netvm is None):
                    dispvm.netvm = vm.dispvm_netvm
            self._log_time("VM started")
            qvm_collection.save()
        finally:
            # Also runs on the early ``return None`` paths and on exceptions.
            qvm_collection.unlock_db()

        # Reload firewall rules
        self._log_time("reloading firewall")
        for proxy_vm in qvm_collection.values():
            if proxy_vm.is_proxyvm() and proxy_vm.is_running():
                proxy_vm.write_iptables_qubesdb_entry()

        return dispvm

    @staticmethod
    def dvm_setup_ok():
        """Tell whether the existing DispVM savefile can be used as is.

        Returns False when the current savefile, ``default-savefile`` or
        ``savefile-root`` is missing, or when the current savefile is older
        than ``savefile-root``; True otherwise.  In the "older" case a tray
        notification is shown if ``xl domid <template>`` exits with 0.
        """
        dvmdata_dir = '/var/lib/qubes/dvmdata/'
        if not os.path.isfile(current_savefile):
            return False
        if not os.path.isfile(dvmdata_dir+'default-savefile') or \
                not os.path.isfile(dvmdata_dir+'savefile-root'):
            return False
        dvm_mtime = os.stat(current_savefile).st_mtime
        root_mtime = os.stat(dvmdata_dir+'savefile-root').st_mtime
        if dvm_mtime < root_mtime:
            # The template name is the directory holding the file that
            # savefile-root links to.
            template_name = os.path.basename(
                os.path.dirname(os.readlink(dvmdata_dir+'savefile-root')))
            # Close the devnull handle deterministically instead of leaking
            # it until garbage collection.
            with open(os.devnull, "w") as devnull:
                domid_status = subprocess.call(
                    ["xl", "domid", template_name], stdout=devnull)
            if domid_status == 0:
                tray_notify("For optimum performance, you should not "
                            "start DispVM when its template is running.", "red")
            return False
        return True

    def get_dvm(self):
        """Return a started DispVM, regenerating the savefile if needed.

        Returns None when the savefile update script exits with a non-zero
        status (a tray error is shown), otherwise whatever do_get_dvm()
        returns.
        """
        if not self.dvm_setup_ok():
            if os.system("/usr/lib/qubes/"
                         "qubes-update-dispvm-savefile-with-progress.sh"
                         " >/dev/null </dev/null") != 0:
                tray_notify_error("DVM savefile creation failed")
                return None
        return self.do_get_dvm()

    @staticmethod
    def remove_disposable_from_qdb(name):
        """Remove the VM called ``name`` from the Qubes VM collection.

        Returns False when no such VM exists.
        NOTE(review): after a successful removal nothing is returned
        (implicit None); kept as is for backward compatibility.
        """
        qvm_collection = QubesVmCollection()
        qvm_collection.lock_db_for_writing()
        try:
            qvm_collection.load()
            vm = qvm_collection.get_vm_by_name(name)
            if vm is None:
                return False
            qvm_collection.pop(vm.qid)
            qvm_collection.save()
        finally:
            # Release the write lock even if load/pop/save raised.
            qvm_collection.unlock_db()


def main():
    """Start a DispVM for the calling domain, run the service, clean up.

    Command line: ``EXEC_INDEX SRC_VMNAME USER [LABEL [FIREWALL_CONF]]``.
    The two optional arguments are not read here; get_dvm() accesses
    ``sys.argv[4]`` (override label) and ``sys.argv[5]`` (override
    firewall) directly.

    Exits with a usage message when the mandatory arguments are missing.
    """
    if len(sys.argv) < 4:
        sys.exit("usage: %s EXEC_INDEX SRC_VMNAME USER "
                 "[LABEL [FIREWALL_CONF]]" % sys.argv[0])
    exec_index, src_vmname, user = sys.argv[1:4]

    sys.stderr.write(
        "time=%s, qfile-daemon-dvm init\n" % (str(time.time())))
    tray_notify_init()
    sys.stderr.write("time=%s, creating DispVM\n" % (str(time.time())))
    qfile = QfileDaemonDvm(src_vmname)
    dispvm = qfile.get_dvm()
    if dispvm is None:
        return

    try:
        sys.stderr.write(
            "time=%s, starting VM process\n" % (str(time.time())))
        subprocess.call(['/usr/lib/qubes/qrexec-client', '-d', dispvm.name,
                         user+':exec /usr/lib/qubes/qubes-rpc-multiplexer ' +
                         exec_index + " " + src_vmname])
    finally:
        # Tear the DispVM down even if running qrexec-client raised or was
        # interrupted, so it is not left running and registered.
        try:
            dispvm.force_shutdown()
        except QubesException:
            # VM already destroyed
            pass
        qfile.remove_disposable_from_qdb(dispvm.name)

# Run the daemon only when executed as a script, not when imported.
if __name__ == '__main__':
    main()