Browse Source

storage/file: import data into temporary volume

Similar to LVM changes, this fixes/improves multiple things:
 - no old data visible in the volume
 - failed import do not leave broken volume
 - parially imported data not visible to running VM

QubesOS/qubes-issues#3169
Marek Marczykowski-Górecki 6 years ago
parent
commit
510fad9163
3 changed files with 61 additions and 3 deletions
  1. 18 2
      qubes/storage/file.py
  2. 1 1
      qubes/tests/api_admin.py
  3. 42 0
      qubes/tests/storage_file.py

+ 18 - 2
qubes/storage/file.py

@@ -269,9 +269,20 @@ class FileVolume(qubes.storage.Volume):
             copy_file(src_volume.export(), self.path)
         return self
 
-
     def import_data(self):
-        return self.path
+        if not self.save_on_stop:
+            raise qubes.storage.StoragePoolException(
+                "Can not import into save_on_stop=False volume {!s}".format(
+                    self))
+        create_sparse_file(self.path_import, self.size)
+        return self.path_import
+
+    def import_data_end(self, success):
+        if success:
+            os.rename(self.path_import, self.path)
+        else:
+            os.unlink(self.path_import)
+        return self
 
     def reset(self):
         ''' Remove and recreate a volatile volume '''
@@ -321,6 +332,11 @@ class FileVolume(qubes.storage.Volume):
         img_name = self.vid + '-cow.img'
         return os.path.join(self.dir_path, img_name)
 
+    @property
+    def path_import(self):
+        img_name = self.vid + '-import.img'
+        return os.path.join(self.dir_path, img_name)
+
     def verify(self):
         ''' Verifies the volume. '''
         if not os.path.exists(self.path) and \

+ 1 - 1
qubes/tests/api_admin.py

@@ -1624,7 +1624,7 @@ class TC_00_VMs(AdminAPITestCase):
         value = self.call_mgmt_func(b'admin.vm.volume.Import', b'test-vm1',
             b'private')
         self.assertEqual(value, '{} {}'.format(
-            2*2**30, '/tmp/qubes-test-dir/appvms/test-vm1/private.img'))
+            2*2**30, '/tmp/qubes-test-dir/appvms/test-vm1/private-import.img'))
         self.assertFalse(self.app.save.called)
 
     def test_511_vm_volume_import_running(self):

+ 42 - 0
qubes/tests/storage_file.py

@@ -312,6 +312,48 @@ class TC_01_FileVolumes(qubes.tests.QubesTestCase):
             volume.revisions_to_keep = 2
         self.assertEqual(volume.revisions_to_keep, 1)
 
+    def test_020_import_data(self):
+        config = {
+            'name': 'root',
+            'pool': self.POOL_NAME,
+            'save_on_stop': True,
+            'rw': True,
+            'size': 1024 * 1024,
+        }
+        vm = qubes.tests.storage.TestVM(self)
+        volume = self.app.get_pool(self.POOL_NAME).init_volume(vm, config)
+        volume.create()
+        import_path = volume.import_data()
+        self.assertNotEqual(volume.path, import_path)
+        with open(import_path, 'w+') as import_file:
+            import_file.write('test')
+        volume.import_data_end(True)
+        self.assertFalse(os.path.exists(import_path), import_path)
+        with open(volume.path) as volume_file:
+            volume_data = volume_file.read().strip('\0')
+        self.assertEqual(volume_data, 'test')
+
+    def test_021_import_data_fail(self):
+        config = {
+            'name': 'root',
+            'pool': self.POOL_NAME,
+            'save_on_stop': True,
+            'rw': True,
+            'size': 1024 * 1024,
+        }
+        vm = qubes.tests.storage.TestVM(self)
+        volume = self.app.get_pool(self.POOL_NAME).init_volume(vm, config)
+        volume.create()
+        import_path = volume.import_data()
+        self.assertNotEqual(volume.path, import_path)
+        with open(import_path, 'w+') as import_file:
+            import_file.write('test')
+        volume.import_data_end(False)
+        self.assertFalse(os.path.exists(import_path), import_path)
+        with open(volume.path) as volume_file:
+            volume_data = volume_file.read().strip('\0')
+        self.assertNotEqual(volume_data, 'test')
+
     def assertVolumePath(self, vm, dev_name, expected, rw=True):
         # :pylint: disable=invalid-name
         volumes = vm.volumes