#!/usr/bin/python2.6
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2010  Joanna Rutkowska <joanna@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
#

from qubes.qubes import QubesVmCollection
from qubes.qubes import QubesException
from qubes.qubes import qubes_store_filename
from qubes.qubes import qubes_base_dir
from qubes.qubes import qubes_templates_dir
from optparse import OptionParser

import os
import time
import subprocess
import sys

def size_to_human (size):
    """Render a byte count as a short human-readable string.

    Values below 1024 are returned as the plain number without a unit.
    Larger values are scaled to KiB, MiB or GiB (GiB is also used for
    anything bigger) and rounded to one decimal place.
    """
    if size < 1024:
        return str (size)

    # (divisor, suffix) pairs for the units that have an upper bound.
    bounded_units = (
        (1024.0, ' KiB'),
        (1024.0*1024, ' MiB'),
    )
    for divisor, suffix in bounded_units:
        # A unit applies while the value is below 1024 of that unit.
        if size < divisor * 1024:
            return str(round(size / divisor, 1)) + suffix

    return str(round(size / (1024.0*1024*1024), 1)) + ' GiB'

# Table columns that can be shown for each VM found in the backup.
#
# Every entry maps a column name to a dict whose "func" value is a Python
# expression stored as a string.  The expression is evaluated with eval()
# while main() builds the table, in a scope where `vm` is the VM being
# displayed and `backup_collection` is the collection loaded from the
# backup, so it may only refer to those two names.  main() additionally
# stores a "max_width" key in the entries it displays.
fields = {
    # Numeric id of the VM.
    "qid": {"func": "vm.qid"},

    # VM name with decorations: a '=>' prefix when the VM is the backup
    # collection's default template, [name] for templates, {name} for netvms.
    "name": {"func": "('=>' if backup_collection.get_default_template_vm() is not None\
             and vm.qid == backup_collection.get_default_template_vm().qid else '')\
             + ('[' if vm.is_templete() else '')\
             + ('{' if vm.is_netvm() else '')\
             + vm.name \
             + (']' if vm.is_templete() else '')\
             + ('}' if vm.is_netvm() else '')"},

    # Short VM kind: 'Tpl', ' Net' (note the leading space) or 'App'.
    "type": {"func": "'Tpl' if vm.is_templete() else \
             (' Net' if vm.is_netvm() else 'App')"},

    # 'Yes' when the VM reports itself as updateable, empty otherwise.
    "updbl" : {"func": "'Yes' if vm.is_updateable() else ''"},

    # Name of the VM's template, or 'n/a' for templates and netvms.
    "template": {"func": "'n/a' if vm.is_templete() or vm.is_netvm() else\
                 backup_collection[vm.template_vm.qid].name"},

    # Name of the VM's netvm: 'n/a' for netvms themselves, '-' when the VM
    # has no netvm, and prefixed with '*' when the default netvm is used.
    "netvm": {"func": "'n/a' if vm.is_netvm() else\
              ('*' if vm.uses_default_netvm else '') +\
              backup_collection[vm.netvm_vm.qid].name\
                     if vm.netvm_vm is not None else '-'"},

    # Name of the VM's label.
    "label" : {"func" : "vm.label.name"},
}

def is_vm_included_in_backup (backup_dir, vm):
    """Tell whether the backup directory holds data for the given VM.

    backup_dir -- root directory of the backup
    vm         -- VM object; its qid and dir_path attributes are read

    Returns True when the VM's directory, with the qubes_base_dir part of
    its path replaced by backup_dir, exists; False otherwise.
    """
    # Dom0 (qid 0) is not included, obviously
    if vm.qid == 0:
        return False

    expected_dir = vm.dir_path.replace (qubes_base_dir, backup_dir)
    return os.path.exists (expected_dir)

def restore_vm_file (backup_dir, file_path):
    """Copy one file from the backup to its original location.

    backup_dir -- root directory of the backup
    file_path  -- destination path on the host; the matching source path is
                  obtained by replacing qubes_base_dir with backup_dir

    Terminates the program with exit status 1 when the copy fails.
    """
    source_path = file_path.replace (qubes_base_dir, backup_dir)

    # We prefer to use Linux's cp, because it nicely handles sparse files
    copy_cmd = ["cp", "-p", source_path, file_path]
    if subprocess.call (copy_cmd) == 0:
        return

    print ("*** Error while copying file {0} to {1}".format(source_path, file_path))
    exit (1)

def restore_vm_dir (backup_dir, src_dir, dst_dir):
    """Recursively copy a VM directory from the backup into dst_dir.

    backup_dir -- root directory of the backup
    src_dir    -- original directory path on the host; the matching path
                  inside the backup is obtained by replacing qubes_base_dir
                  with backup_dir
    dst_dir    -- directory passed to cp as the copy destination

    Terminates the program with exit status 1 when the copy fails.
    """
    backup_src_dir = src_dir.replace (qubes_base_dir, backup_dir)

    # We prefer to use Linux's cp, because it nicely handles sparse files
    retcode = subprocess.call (["cp", "-rp", backup_src_dir, dst_dir])
    if retcode != 0:
        # The original referenced an undefined name here, which turned a
        # failed copy into a NameError instead of this message.
        print ("*** Error while copying file {0} to {1}".format(backup_src_dir, dst_dir))
        sys.exit (1)


def _field_value (field, vm, backup_collection):
    """Evaluate the display expression of column `field` for one VM."""
    # The expressions stored in `fields` are fixed strings defined in this
    # file (not user input).  They refer to the names `vm` and
    # `backup_collection`, which is why those are parameters here.
    return eval (fields[field]["func"])


def _table_separator (fields_to_display):
    """Build the horizontal rule matching the computed column widths."""
    return "".join ("-" * (fields[f]["max_width"] + 2) + "+"
                    for f in fields_to_display)


def _format_row (fields_to_display, values):
    """Right-align each value in its column and join the cells."""
    cells = []
    for f, value in zip (fields_to_display, values):
        fmt = "{{0:>{0}}} |".format(fields[f]["max_width"] + 1)
        cells.append (fmt.format(value))
    return "".join (cells)


def _has_template (collection, name):
    """Tell whether `collection` holds a template VM called `name`."""
    candidate = collection.get_vm_by_name (name)
    return candidate is not None and candidate.is_templete()


def _has_netvm (collection, name):
    """Tell whether `collection` holds a network VM called `name`."""
    candidate = collection.get_vm_by_name (name)
    return candidate is not None and candidate.is_netvm()


def _report_missing_dependency (kind, options):
    """Tell the user how missing `kind` VMs are handled; exit if they aren't."""
    print ("*** One or more {0} VM is missing on the host! ***".format(kind))
    if options.ignore_missing:
        # --ignore-missing wins over --skip-broken when both are given,
        # because the affected VMs are kept in the restore list.
        print ("... VMs that depend on it be restored anyway (--ignore-missing used)")
    elif options.skip_broken:
        print ("... VMs that depend on it will not be restored (--skip-broken used)")
    else:
        print ("Install it first, before proceeding with backup restore.")
        print ("Or pass: --skip-broken or --ignore-missing switch.")
        sys.exit (1)


def _verify_and_attach (host_collection, vm):
    """Verify a newly added VM and add it to Xen storage.

    On failure of either step the error is reported and the VM is removed
    from host_collection again.
    """
    try:
        vm.verify_files()
    except QubesException as err:
        print ("ERROR: {0}".format(err))
        print ("*** Skiping VM: {0}".format(vm.name))
        host_collection.pop(vm.qid)
        return

    try:
        vm.add_to_xen_storage()
    except (IOError, OSError) as err:
        print ("ERROR: {0}".format(err))
        print ("*** Skiping VM: {0}".format(vm.name))
        host_collection.pop(vm.qid)


def main():
    """Restore the VMs stored in a backup directory onto this host.

    Command line: [--skip-broken] [--ignore-missing] <backup-dir>

    Steps:
      1. Load the VM collection from <backup-dir>/qubes.xml and the host
         collection (the latter locked for writing).
      2. Print a table of the VMs included in the backup, flagging name
         conflicts with the host and missing templates/netvms.
      3. Abort on name conflicts, or on missing dependencies unless one of
         the two switches was given.
      4. After confirmation, copy the VM files, register templates and then
         AppVMs in the host collection, and save it.

    Exits with status 1 on usage errors, a missing backup directory, name
    conflicts or unresolved dependencies, and with status 0 when the user
    declines to proceed.
    """
    usage = "usage: %prog [options] <backup-dir>"
    parser = OptionParser (usage)

    parser.add_option ("--skip-broken", action="store_true", dest="skip_broken", default=False,
                       help="Do not restore VMs that have missing templates or netvms")

    parser.add_option ("--ignore-missing", action="store_true", dest="ignore_missing", default=False,
                       help="Ignore missing templates or netvms, restore VMs anyway")

    (options, args) = parser.parse_args ()

    if len (args) != 1:
        print ("You must specify the backup directory (e.g. /mnt/backup/qubes-2010-12-01-235959)")
        # A usage error is a failure, so report it with a non-zero status.
        sys.exit (1)

    backup_dir = args[0]

    if not os.path.exists (backup_dir):
        print ("The backup directory doesn't exist!")
        sys.exit (1)

    backup_collection = QubesVmCollection(store_filename = backup_dir + "/qubes.xml")
    backup_collection.lock_db_for_reading()
    backup_collection.load()

    host_collection = QubesVmCollection()
    host_collection.lock_db_for_writing()
    host_collection.load()

    backup_vms_list = list (backup_collection.values())
    vms_to_restore = []

    fields_to_display = ["name", "type", "template", "updbl", "netvm", "label" ]

    # First calculate the maximum width of each field we want to display
    for f in fields_to_display:
        width = len(f)
        for vm in backup_vms_list:
            width = max (width, len(str(_field_value (f, vm, backup_collection))))
        fields[f]["max_width"] = width

    print ("")
    print ("The following VMs are included in the backup:")
    print ("")

    # Display the header
    print (_table_separator (fields_to_display))
    print (_format_row (fields_to_display, fields_to_display))
    print (_table_separator (fields_to_display))

    there_are_conflicting_vms = False
    there_are_missing_templates = False
    there_are_missing_netvms = False
    # ... and the actual data
    for vm in backup_vms_list:
        if not is_vm_included_in_backup (backup_dir, vm):
            continue

        good_to_go = True
        row = _format_row (fields_to_display,
                           [_field_value (f, vm, backup_collection)
                            for f in fields_to_display])

        if host_collection.get_vm_by_name (vm.name) is not None:
            row += " <-- A VM with the same name already exists on the host!"
            there_are_conflicting_vms = True
            good_to_go = False

        if vm.is_appvm():
            templatevm_name = vm.template_vm.name
            # Look on the host first; maybe the (custom) template is in
            # the backup instead.
            # NOTE(review): the backup side only consults the backup's VM
            # collection, not whether the template's files were backed up.
            if not (_has_template (host_collection, templatevm_name)
                    or _has_template (backup_collection, templatevm_name)):
                row += " <-- No matching template on the host or in the backup found!"
                there_are_missing_templates = True
                # Only ever lower good_to_go, never reset it to True.
                if not options.ignore_missing:
                    good_to_go = False

        if not vm.is_netvm() and vm.netvm_vm is not None:
            netvm_name = vm.netvm_vm.name
            # Same lookup order as for templates: host, then backup.
            if not (_has_netvm (host_collection, netvm_name)
                    or _has_netvm (backup_collection, netvm_name)):
                row += " <-- No matching netvm on the host found!"
                there_are_missing_netvms = True
                if not options.ignore_missing:
                    good_to_go = False

        print (row)
        if good_to_go:
            vms_to_restore.append (vm)

    print ("")

    if there_are_conflicting_vms:
        print ("*** There VMs with conflicting names on the host! ***")
        print ("Remove VMs with conflicting names from the host before proceeding.")
        print ("You can use 'qvm-remove <vmname>' command for this.")
        print ("Be careful! Think twice before typing qvm-remove!!!")
        sys.exit (1)

    print ("The above VMs will be copied and added to your system.")
    print ("Exisiting VMs will not be removed.")

    if there_are_missing_templates:
        _report_missing_dependency ("template", options)

    if there_are_missing_netvms:
        _report_missing_dependency ("network", options)

    prompt = raw_input ("Do you want to proceed? [y/N] ")
    if prompt not in ("y", "Y"):
        sys.exit (0)

    # Copy the files of every selected VM into place.
    for vm in vms_to_restore:
        print ("-> Restoring: {0} ...".format(vm.name))

        retcode = subprocess.call (["mkdir", "-p", vm.dir_path])
        if retcode != 0:
            print ("*** Cannot create directory: {0}?!".format(vm.dir_path))
            print ("Skiping...")
            continue

        if vm.is_appvm():
            restore_vm_file (backup_dir, vm.private_img)
            restore_vm_file (backup_dir, vm.icon_path)
            restore_vm_file (backup_dir, vm.conf_file)

            if vm.is_updateable():
                restore_vm_file (backup_dir, vm.rootcow_img)

        elif vm.is_templete():
            restore_vm_dir (backup_dir, vm.dir_path, qubes_templates_dir)
        else:
            print ("ERROR: VM '{0}', type='{1}': unsupported VM type!".format(vm.name, vm.type))

    # Add templates first, so the AppVMs added below can look them up
    # in the host collection...
    for backup_vm in [vm for vm in vms_to_restore if vm.is_templete()]:
        print ("-> Adding Template VM {0}...".format(backup_vm.name))
        new_vm = host_collection.add_new_templatevm(backup_vm.name,
                                                   conf_file=backup_vm.conf_file,
                                                   dir_path=backup_vm.dir_path,
                                                   installed_by_rpm=False)

        new_vm.updateable = backup_vm.updateable
        _verify_and_attach (host_collection, new_vm)

    # ... then appvms...
    for backup_vm in [vm for vm in vms_to_restore if vm.is_appvm()]:
        print ("-> Adding AppVM {0}...".format(backup_vm.name))
        template_vm = host_collection.get_vm_by_name(backup_vm.template_vm.name)

        # A non-default netvm is re-resolved by name on the host.
        netvm_vm = None
        if not backup_vm.uses_default_netvm and backup_vm.netvm_vm is not None:
            netvm_vm = host_collection.get_vm_by_name (backup_vm.netvm_vm.name)

        new_vm = host_collection.add_new_appvm(backup_vm.name, template_vm,
                                               conf_file=backup_vm.conf_file,
                                               dir_path=backup_vm.dir_path,
                                               label=backup_vm.label)

        new_vm.updateable = backup_vm.updateable
        if not backup_vm.uses_default_netvm:
            new_vm.uses_default_netvm = False
            new_vm.netvm_vm = netvm_vm

        new_vm.create_appmenus(verbose=True)
        # Stops after the first failing step, so a VM that fails
        # verification is not passed on to add_to_xen_storage().
        _verify_and_attach (host_collection, new_vm)

    backup_collection.unlock_db()
    host_collection.save()
    host_collection.unlock_db()
    print ("-> Done.")
# Run the restore only when executed as a script, not when imported.
if __name__ == "__main__":
    main()