diff --git a/qubes/storage/__init__.py b/qubes/storage/__init__.py index 4e23b091..8d2b7a1b 100644 --- a/qubes/storage/__init__.py +++ b/qubes/storage/__init__.py @@ -31,6 +31,7 @@ import string # pylint: disable=deprecated-module import time from datetime import datetime +import asyncio import lxml.etree import pkg_resources import qubes @@ -335,24 +336,37 @@ class Storage(object): result += volume.usage return result + @asyncio.coroutine def resize(self, volume, size): ''' Resizes volume a read-writable volume ''' if isinstance(volume, str): volume = self.vm.volumes[volume] - self.get_pool(volume).resize(volume, size) + ret = self.get_pool(volume).resize(volume, size) + if asyncio.iscoroutine(ret): + yield from ret if self.vm.is_running(): - self.vm.run_service('qubes.ResizeDisk', input=volume.name.encode(), - user='root', wait=True) + yield from self.vm.run_service_for_stdio('qubes.ResizeDisk', + input=volume.name.encode(), + user='root') + @asyncio.coroutine def create(self): ''' Creates volumes on disk ''' old_umask = os.umask(0o002) + coros = [] for volume in self.vm.volumes.values(): - self.get_pool(volume).create(volume) + # launch the operation, if it's asynchronous, then append to wait + # for them at the end + ret = self.get_pool(volume).create(volume) + if asyncio.iscoroutine(ret): + coros.append(ret) + if coros: + yield from asyncio.wait(coros) os.umask(old_umask) + @asyncio.coroutine def clone(self, src_vm): ''' Clone volumes from the specified vm ''' @@ -365,6 +379,14 @@ class Storage(object): assert not os.path.exists(dst_path), msg os.mkdir(dst_path) + # clone/import functions may be either synchronous or asynchronous + # in the later case, we need to wait for them to finish + clone_op = {} + + msg = "Cloning directory: {!s} to {!s}" + msg = msg.format(src_path, dst_path) + self.log.info(msg) + self.vm.volumes = {} with VmCreationManager(self.vm): for name, config in self.vm.volume_config.items(): @@ -375,22 +397,30 @@ class Storage(object): if dst_pool == src_pool: msg = "Cloning volume {!s} from vm {!s}" self.vm.log.info(msg.format(src_volume.name, src_vm.name)) - volume = dst_pool.clone(src_volume, dst) + clone_op_ret = dst_pool.clone(src_volume, dst) else: msg = "Importing volume {!s} from vm {!s}" self.vm.log.info(msg.format(src_volume.name, src_vm.name)) - volume = dst_pool.import_volume(dst_pool, dst, src_pool, - src_volume) + clone_op_ret = dst_pool.import_volume( + dst_pool, dst, src_pool, src_volume) + if asyncio.iscoroutine(clone_op_ret): + clone_op[name] = asyncio.ensure_future(clone_op_ret) + + yield from asyncio.wait(x for x in clone_op.values() + if asyncio.isfuture(x)) + + for name, clone_op_ret in clone_op.items(): + if asyncio.isfuture(clone_op_ret): + volume = clone_op_ret.result + else: + volume = clone_op_ret assert volume, "%s.clone() returned '%s'" % ( - dst_pool.__class__.__name__, volume) + self.get_pool(self.vm.volume_config[name]['pool']). + __class__.__name__, volume) self.vm.volumes[name] = volume - msg = "Cloning directory: {!s} to {!s}" - msg = msg.format(src_path, dst_path) - self.log.info(msg) - @property def outdated_volumes(self): ''' Returns a list of outdated volumes ''' @@ -431,28 +461,52 @@ class Storage(object): self.vm.fire_event('domain-verify-files') return True + @asyncio.coroutine def remove(self): ''' Remove all the volumes. Errors on removal are catched and logged. ''' + futures = [] for name, volume in self.vm.volumes.items(): self.log.info('Removing volume %s: %s' % (name, volume.vid)) try: - self.get_pool(volume).remove(volume) + ret = self.get_pool(volume).remove(volume) + if asyncio.iscoroutine(ret): + futures.append(ret) except (IOError, OSError) as e: self.vm.log.exception("Failed to remove volume %s", name, e) + if futures: + try: + yield from asyncio.wait(futures) + except (IOError, OSError) as e: + self.vm.log.exception("Failed to remove some volume", e) + + @asyncio.coroutine def start(self): ''' Execute the start method on each pool ''' + futures = [] for volume in self.vm.volumes.values(): pool = self.get_pool(volume) - volume = pool.start(volume) + ret = pool.start(volume) + if asyncio.iscoroutine(ret): + futures.append(ret) + if futures: + yield from asyncio.wait(futures) + + @asyncio.coroutine def stop(self): ''' Execute the start method on each pool ''' + futures = [] for volume in self.vm.volumes.values(): - self.get_pool(volume).stop(volume) + ret = self.get_pool(volume).stop(volume) + if asyncio.iscoroutine(ret): + futures.append(ret) + + if futures: + yield from asyncio.wait(futures) def get_pool(self, volume): ''' Helper function ''' @@ -463,11 +517,18 @@ class Storage(object): return self.vm.app.pools[volume] + @asyncio.coroutine def commit(self): ''' Makes changes to an 'origin' volume persistent ''' + futures = [] for volume in self.vm.volumes.values(): if volume.save_on_stop: - self.get_pool(volume).commit(volume) + ret = self.get_pool(volume).commit(volume) + if asyncio.iscoroutine(ret): + futures.append(ret) + + if futures: + yield asyncio.wait(futures) def unused_frontend(self): ''' Find an unused device name ''' @@ -529,11 +590,15 @@ class Pool(object): def create(self, volume): ''' Create the given volume on disk or copy from provided `source_volume`. + + This can be implemented as a coroutine. ''' raise self._not_implemented("create") def commit(self, volume): # pylint: disable=no-self-use - ''' Write the snapshot to disk ''' + ''' Write the snapshot to disk + + This can be implemented as a coroutine.''' msg = "Got volume_type {!s} when expected 'snap'" msg = msg.format(volume.volume_type) assert volume.volume_type == 'snap', msg @@ -544,7 +609,9 @@ class Pool(object): raise self._not_implemented("config") def clone(self, source, target): - ''' Clone volume ''' + ''' Clone volume. + + This can be implemented as a coroutine. ''' raise self._not_implemented("clone") def destroy(self): @@ -581,7 +648,9 @@ class Pool(object): raise self._not_implemented("recover") def remove(self, volume): - ''' Remove volume''' + ''' Remove volume. + + This can be implemented as a coroutine.''' raise self._not_implemented("remove") def rename(self, volume, old_name, new_name): @@ -597,6 +666,8 @@ class Pool(object): ''' Expands volume, throws :py:class:`qubes.storage.StoragePoolException` if given size is less than current_size + + This can be implemented as a coroutine. ''' raise self._not_implemented("resize") @@ -611,11 +682,15 @@ class Pool(object): raise self._not_implemented("setup") def start(self, volume): # pylint: disable=no-self-use - ''' Do what ever is needed on start ''' + ''' Do what ever is needed on start + + This can be implemented as a coroutine.''' raise self._not_implemented("start") def stop(self, volume): # pylint: disable=no-self-use - ''' Do what ever is needed on stop''' + ''' Do what ever is needed on stop + + This can be implemented as a coroutine.''' def verify(self, volume): ''' Verifies the volume. ''' diff --git a/qubes/tests/storage_file.py b/qubes/tests/storage_file.py index 1f0547ad..ca22b603 100644 --- a/qubes/tests/storage_file.py +++ b/qubes/tests/storage_file.py @@ -23,6 +23,8 @@ import os import shutil +import asyncio + import qubes.storage import qubes.tests.storage from qubes.config import defaults @@ -311,7 +313,8 @@ class TC_03_FilePool(qubes.tests.QubesTestCase): 'pool': 'test-pool' } }, label='red') - vm.create_on_disk() + loop = asyncio.get_event_loop() + loop.run_until_complete(vm.create_on_disk()) expected_vmdir = os.path.join(self.APPVMS_DIR, vm.name) @@ -341,7 +344,8 @@ class TC_03_FilePool(qubes.tests.QubesTestCase): 'pool': 'test-pool' } }, label='red') - vm.create_on_disk() + loop = asyncio.get_event_loop() + loop.run_until_complete(vm.create_on_disk()) expected_vmdir = os.path.join(self.TEMPLATES_DIR, vm.name) diff --git a/qubes/vm/qubesvm.py b/qubes/vm/qubesvm.py index 82cd75ff..b0c5a5f1 100644 --- a/qubes/vm/qubesvm.py +++ b/qubes/vm/qubesvm.py @@ -837,8 +837,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): qmemman_client = yield from asyncio.get_event_loop().run_in_executor( None, self.request_memory, mem_required) - yield from asyncio.get_event_loop().run_in_executor(None, - self.storage.start) + yield from self.storage.start() self._update_libvirt_domain() try: @@ -906,8 +905,8 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): self.libvirt_domain.shutdown() - yield from asyncio.get_event_loop().run_in_executor(None, - self.storage.stop) + # FIXME: move to libvirt domain destroy event handler + yield from self.storage.stop() while wait and not self.is_halted(): yield from asyncio.sleep(0.25) @@ -926,8 +925,8 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): raise qubes.exc.QubesVMNotStartedError(self) self.libvirt_domain.destroy() - yield from asyncio.get_event_loop().run_in_executor(None, - self.storage.stop) + # FIXME: move to libvirt domain destroy event handler + yield from self.storage.stop() return self @@ -1229,6 +1228,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): self.have_session.set() self.fire_event('domain-has-session') + @asyncio.coroutine def create_on_disk(self, pool=None, pools=None): '''Create files needed for VM. ''' @@ -1242,7 +1242,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): pools) self.storage = qubes.storage.Storage(self) - self.storage.create() + yield from self.storage.create() self.log.info('Creating icon symlink: {} -> {}'.format( self.icon_path, self.label.icon_path)) @@ -1254,6 +1254,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): # fire hooks self.fire_event('domain-create-on-disk') + @asyncio.coroutine def remove_from_disk(self): '''Remove domain remnants from disk.''' if not self.is_halted(): @@ -1263,14 +1264,16 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): self.fire_event('domain-remove-from-disk') try: + # TODO: make it async? shutil.rmtree(self.dir_path) except OSError as e: if e.errno == errno.ENOENT: pass else: raise - self.storage.remove() + yield from self.storage.remove() + @asyncio.coroutine def clone_disk_files(self, src, pool=None, pools=None, ): '''Clone files from other vm. @@ -1291,7 +1294,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM): pools) self.storage = qubes.storage.Storage(self) - self.storage.clone(src) + yield from self.storage.clone(src) self.storage.verify() assert self.volumes != {}