Selaa lähdekoodia

storage: support asynchronous storage pool implementations

Allow specific pool implementation to provide asynchronous
implementation. vm.storage.* methods will detect if given implementation
is synchronous or asynchronous and will act accordingly.
Then it's up to pool implementation how asynchronous should be achieved.
Do not force it using threads (`run_in_executor()`). But pool
implementation is free to use threads, if consider it safe in a
particular case.

This commit does not touch any pool implementation - all of them are
still synchronous.

QubesOS/qubes-issues#2256
Marek Marczykowski-Górecki 7 vuotta sitten
vanhempi
commit
52c3753d61
3 muutettua tiedostoa jossa 114 lisäystä ja 32 poistoa
  1. 96 21
      qubes/storage/__init__.py
  2. 6 2
      qubes/tests/storage_file.py
  3. 12 9
      qubes/vm/qubesvm.py

+ 96 - 21
qubes/storage/__init__.py

@@ -31,6 +31,7 @@ import string  # pylint: disable=deprecated-module
 import time
 from datetime import datetime
 
+import asyncio
 import lxml.etree
 import pkg_resources
 import qubes
@@ -335,24 +336,37 @@ class Storage(object):
             result += volume.usage
         return result
 
+    @asyncio.coroutine
     def resize(self, volume, size):
         ''' Resizes volume a read-writable volume '''
         if isinstance(volume, str):
             volume = self.vm.volumes[volume]
-        self.get_pool(volume).resize(volume, size)
+        ret = self.get_pool(volume).resize(volume, size)
+        if asyncio.iscoroutine(ret):
+            yield from ret
         if self.vm.is_running():
-            self.vm.run_service('qubes.ResizeDisk', input=volume.name.encode(),
-                user='root', wait=True)
+            yield from self.vm.run_service_for_stdio('qubes.ResizeDisk',
+                input=volume.name.encode(),
+                user='root')
 
+    @asyncio.coroutine
     def create(self):
         ''' Creates volumes on disk '''
         old_umask = os.umask(0o002)
 
+        coros = []
         for volume in self.vm.volumes.values():
-            self.get_pool(volume).create(volume)
+            # launch the operation, if it's asynchronous, then append to wait
+            #  for them at the end
+            ret = self.get_pool(volume).create(volume)
+            if asyncio.iscoroutine(ret):
+                coros.append(ret)
+        if coros:
+            yield from asyncio.wait(coros)
 
         os.umask(old_umask)
 
+    @asyncio.coroutine
     def clone(self, src_vm):
         ''' Clone volumes from the specified vm '''
 
@@ -365,6 +379,14 @@ class Storage(object):
         assert not os.path.exists(dst_path), msg
         os.mkdir(dst_path)
 
+        # clone/import functions may be either synchronous or asynchronous
+        # in the later case, we need to wait for them to finish
+        clone_op = {}
+
+        msg = "Cloning directory: {!s} to {!s}"
+        msg = msg.format(src_path, dst_path)
+        self.log.info(msg)
+
         self.vm.volumes = {}
         with VmCreationManager(self.vm):
             for name, config in self.vm.volume_config.items():
@@ -375,22 +397,30 @@ class Storage(object):
                 if dst_pool == src_pool:
                     msg = "Cloning volume {!s} from vm {!s}"
                     self.vm.log.info(msg.format(src_volume.name, src_vm.name))
-                    volume = dst_pool.clone(src_volume, dst)
+                    clone_op_ret = dst_pool.clone(src_volume, dst)
                 else:
                     msg = "Importing volume {!s} from vm {!s}"
                     self.vm.log.info(msg.format(src_volume.name, src_vm.name))
-                    volume = dst_pool.import_volume(dst_pool, dst, src_pool,
-                                                    src_volume)
+                    clone_op_ret = dst_pool.import_volume(
+                            dst_pool, dst, src_pool, src_volume)
+                if asyncio.iscoroutine(clone_op_ret):
+                    clone_op[name] = asyncio.ensure_future(clone_op_ret)
+
+            yield from asyncio.wait(x for x in clone_op.values()
+                if asyncio.isfuture(x))
+
+            for name, clone_op_ret in clone_op.items():
+                if asyncio.isfuture(clone_op_ret):
+                    volume = clone_op_ret.result
+                else:
+                    volume = clone_op_ret
 
                 assert volume, "%s.clone() returned '%s'" % (
-                    dst_pool.__class__.__name__, volume)
+                    self.get_pool(self.vm.volume_config[name]['pool']).
+                        __class__.__name__, volume)
 
                 self.vm.volumes[name] = volume
 
-        msg = "Cloning directory: {!s} to {!s}"
-        msg = msg.format(src_path, dst_path)
-        self.log.info(msg)
-
     @property
     def outdated_volumes(self):
         ''' Returns a list of outdated volumes '''
@@ -431,28 +461,52 @@ class Storage(object):
         self.vm.fire_event('domain-verify-files')
         return True
 
+    @asyncio.coroutine
     def remove(self):
         ''' Remove all the volumes.
 
             Errors on removal are catched and logged.
         '''
+        futures = []
         for name, volume in self.vm.volumes.items():
             self.log.info('Removing volume %s: %s' % (name, volume.vid))
             try:
-                self.get_pool(volume).remove(volume)
+                ret = self.get_pool(volume).remove(volume)
+                if asyncio.iscoroutine(ret):
+                    futures.append(ret)
             except (IOError, OSError) as e:
                 self.vm.log.exception("Failed to remove volume %s", name, e)
 
+        if futures:
+            try:
+                yield from asyncio.wait(futures)
+            except (IOError, OSError) as e:
+                self.vm.log.exception("Failed to remove some volume", e)
+
+    @asyncio.coroutine
     def start(self):
         ''' Execute the start method on each pool '''
+        futures = []
         for volume in self.vm.volumes.values():
             pool = self.get_pool(volume)
-            volume = pool.start(volume)
+            ret = pool.start(volume)
+            if asyncio.iscoroutine(ret):
+                futures.append(ret)
+
+        if futures:
+            yield from asyncio.wait(futures)
 
+    @asyncio.coroutine
     def stop(self):
         ''' Execute the start method on each pool '''
+        futures = []
         for volume in self.vm.volumes.values():
-            self.get_pool(volume).stop(volume)
+            ret = self.get_pool(volume).stop(volume)
+            if asyncio.iscoroutine(ret):
+                futures.append(ret)
+
+        if futures:
+            yield from asyncio.wait(futures)
 
     def get_pool(self, volume):
         ''' Helper function '''
@@ -463,11 +517,18 @@ class Storage(object):
 
         return self.vm.app.pools[volume]
 
+    @asyncio.coroutine
     def commit(self):
         ''' Makes changes to an 'origin' volume persistent '''
+        futures = []
         for volume in self.vm.volumes.values():
             if volume.save_on_stop:
-                self.get_pool(volume).commit(volume)
+                ret = self.get_pool(volume).commit(volume)
+                if asyncio.iscoroutine(ret):
+                    futures.append(ret)
+
+        if futures:
+            yield asyncio.wait(futures)
 
     def unused_frontend(self):
         ''' Find an unused device name '''
@@ -529,11 +590,15 @@ class Pool(object):
     def create(self, volume):
         ''' Create the given volume on disk or copy from provided
             `source_volume`.
+
+            This can be implemented as a coroutine.
         '''
         raise self._not_implemented("create")
 
     def commit(self, volume):  # pylint: disable=no-self-use
-        ''' Write the snapshot to disk '''
+        ''' Write the snapshot to disk
+
+        This can be implemented as a coroutine.'''
         msg = "Got volume_type {!s} when expected 'snap'"
         msg = msg.format(volume.volume_type)
         assert volume.volume_type == 'snap', msg
@@ -544,7 +609,9 @@ class Pool(object):
         raise self._not_implemented("config")
 
     def clone(self, source, target):
-        ''' Clone volume '''
+        ''' Clone volume.
+
+        This can be implemented as a coroutine. '''
         raise self._not_implemented("clone")
 
     def destroy(self):
@@ -581,7 +648,9 @@ class Pool(object):
         raise self._not_implemented("recover")
 
     def remove(self, volume):
-        ''' Remove volume'''
+        ''' Remove volume.
+
+        This can be implemented as a coroutine.'''
         raise self._not_implemented("remove")
 
     def rename(self, volume, old_name, new_name):
@@ -597,6 +666,8 @@ class Pool(object):
         ''' Expands volume, throws
             :py:class:`qubes.storage.StoragePoolException` if
             given size is less than current_size
+
+            This can be implemented as a coroutine.
         '''
         raise self._not_implemented("resize")
 
@@ -611,11 +682,15 @@ class Pool(object):
         raise self._not_implemented("setup")
 
     def start(self, volume):  # pylint: disable=no-self-use
-        ''' Do what ever is needed on start '''
+        ''' Do what ever is needed on start
+
+        This can be implemented as a coroutine.'''
         raise self._not_implemented("start")
 
     def stop(self, volume):  # pylint: disable=no-self-use
-        ''' Do what ever is needed on stop'''
+        ''' Do what ever is needed on stop
+
+        This can be implemented as a coroutine.'''
 
     def verify(self, volume):
         ''' Verifies the volume. '''

+ 6 - 2
qubes/tests/storage_file.py

@@ -23,6 +23,8 @@
 import os
 import shutil
 
+import asyncio
+
 import qubes.storage
 import qubes.tests.storage
 from qubes.config import defaults
@@ -311,7 +313,8 @@ class TC_03_FilePool(qubes.tests.QubesTestCase):
                                          'pool': 'test-pool'
                                      }
                                  }, label='red')
-        vm.create_on_disk()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(vm.create_on_disk())
 
         expected_vmdir = os.path.join(self.APPVMS_DIR, vm.name)
 
@@ -341,7 +344,8 @@ class TC_03_FilePool(qubes.tests.QubesTestCase):
                                          'pool': 'test-pool'
                                      }
                                  }, label='red')
-        vm.create_on_disk()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(vm.create_on_disk())
 
         expected_vmdir = os.path.join(self.TEMPLATES_DIR, vm.name)
 

+ 12 - 9
qubes/vm/qubesvm.py

@@ -837,8 +837,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
         qmemman_client = yield from asyncio.get_event_loop().run_in_executor(
             None, self.request_memory, mem_required)
 
-        yield from asyncio.get_event_loop().run_in_executor(None,
-            self.storage.start)
+        yield from self.storage.start()
         self._update_libvirt_domain()
 
         try:
@@ -906,8 +905,8 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
 
         self.libvirt_domain.shutdown()
 
-        yield from asyncio.get_event_loop().run_in_executor(None,
-            self.storage.stop)
+        # FIXME: move to libvirt domain destroy event handler
+        yield from self.storage.stop()
 
         while wait and not self.is_halted():
             yield from asyncio.sleep(0.25)
@@ -926,8 +925,8 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
             raise qubes.exc.QubesVMNotStartedError(self)
 
         self.libvirt_domain.destroy()
-        yield from asyncio.get_event_loop().run_in_executor(None,
-            self.storage.stop)
+        # FIXME: move to libvirt domain destroy event handler
+        yield from self.storage.stop()
 
         return self
 
@@ -1229,6 +1228,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
         self.have_session.set()
         self.fire_event('domain-has-session')
 
+    @asyncio.coroutine
     def create_on_disk(self, pool=None, pools=None):
         '''Create files needed for VM.
         '''
@@ -1242,7 +1242,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
                                                       pools)
             self.storage = qubes.storage.Storage(self)
 
-        self.storage.create()
+        yield from self.storage.create()
 
         self.log.info('Creating icon symlink: {} -> {}'.format(
             self.icon_path, self.label.icon_path))
@@ -1254,6 +1254,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
         # fire hooks
         self.fire_event('domain-create-on-disk')
 
+    @asyncio.coroutine
     def remove_from_disk(self):
         '''Remove domain remnants from disk.'''
         if not self.is_halted():
@@ -1263,14 +1264,16 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
 
         self.fire_event('domain-remove-from-disk')
         try:
+            # TODO: make it async?
             shutil.rmtree(self.dir_path)
         except OSError as e:
             if e.errno == errno.ENOENT:
                 pass
             else:
                 raise
-        self.storage.remove()
+        yield from self.storage.remove()
 
+    @asyncio.coroutine
     def clone_disk_files(self, src, pool=None, pools=None, ):
         '''Clone files from other vm.
 
@@ -1291,7 +1294,7 @@ class QubesVM(qubes.vm.mix.net.NetVMMixin, qubes.vm.BaseVM):
                                                       pools)
 
         self.storage = qubes.storage.Storage(self)
-        self.storage.clone(src)
+        yield from self.storage.clone(src)
         self.storage.verify()
         assert self.volumes != {}