Browse Source

utils: take tweaked helper functions from storage/reflink

replace_file(), rename_file(), and remove_file() now have optional
'logger' and 'log_level' (defaulting to DEBUG) arguments.

replace_file() now has a required 'permissions' and an optional
'close_on_success' (defaulting to True) argument. Also, it doesn't
create any directories; and in case of an exception, the tempfile is
removed even when closing it raises another exception.

remove_file() now returns a value: True if the file was removed, or
False if it already didn't exist.

(fsync_path() is unchanged.)

!!! After cherry-picking for release4.0, consider a fixup !!!
!!! adding 'import qubes.utils' to storage/reflink there  !!!
Rusty Bird 3 years ago
parent
commit
c988a2218b
2 changed files with 71 additions and 44 deletions
  1. 13 44
      qubes/storage/reflink.py
  2. 58 0
      qubes/utils.py

+ 13 - 44
qubes/storage/reflink.py

@@ -32,7 +32,7 @@ import logging
 import os
 import subprocess
 import tempfile
-from contextlib import contextmanager, suppress
+from contextlib import suppress
 
 import qubes.storage
 import qubes.utils
@@ -224,7 +224,7 @@ class ReflinkVolume(qubes.storage.Volume):
     def _commit(self, path_from):
         self._add_revision()
         self._prune_revisions()
-        _fsync_path(path_from)
+        qubes.utils.fsync_path(path_from)
         _rename_file(path_from, self._path_clean)
 
     def _add_revision(self):
@@ -354,32 +354,16 @@ class ReflinkVolume(qubes.storage.Volume):
         return 0
 
 
-@contextmanager
 def _replace_file(dst):
-    ''' Yield a tempfile whose name starts with dst, creating the last
-        directory component if necessary. If the block does not raise
-        an exception, safely rename the tempfile to dst.
-    '''
-    tmp_dir, prefix = os.path.split(dst + '~')
-    _make_dir(tmp_dir)
-    tmp = tempfile.NamedTemporaryFile(dir=tmp_dir, prefix=prefix, delete=False)
-    try:
-        yield tmp
-        tmp.flush()
-        os.fsync(tmp.fileno())
-        tmp.close()
-        _rename_file(tmp.name, dst)
-    except:
-        tmp.close()
-        _remove_file(tmp.name)
-        raise
-
-def _fsync_path(path):
-    fd = os.open(path, os.O_RDONLY)  # works for a file or a directory
-    try:
-        os.fsync(fd)
-    finally:
-        os.close(fd)
+    _make_dir(os.path.dirname(dst))
+    return qubes.utils.replace_file(
+        dst, permissions=0o600, log_level=logging.INFO)
+
+_rename_file = functools.partial(
+    qubes.utils.rename_file, log_level=logging.INFO)
+
+_remove_file = functools.partial(
+    qubes.utils.remove_file, log_level=logging.INFO)
 
 def _make_dir(path):
     ''' mkdir path, ignoring FileExistsError; return whether we
@@ -387,35 +371,20 @@ def _make_dir(path):
     '''
     with suppress(FileExistsError):
         os.mkdir(path)
-        _fsync_path(os.path.dirname(path))
+        qubes.utils.fsync_path(os.path.dirname(path))
         LOGGER.info('Created directory: %r', path)
         return True
     return False
 
-def _remove_file(path):
-    with suppress(FileNotFoundError):
-        os.remove(path)
-        _fsync_path(os.path.dirname(path))
-        LOGGER.info('Removed file: %r', path)
-
 def _remove_empty_dir(path):
     try:
         os.rmdir(path)
-        _fsync_path(os.path.dirname(path))
+        qubes.utils.fsync_path(os.path.dirname(path))
         LOGGER.info('Removed empty directory: %r', path)
     except OSError as ex:
         if ex.errno not in (errno.ENOENT, errno.ENOTEMPTY):
             raise
 
-def _rename_file(src, dst):
-    os.rename(src, dst)
-    dst_dir = os.path.dirname(dst)
-    src_dir = os.path.dirname(src)
-    _fsync_path(dst_dir)
-    if src_dir != dst_dir:
-        _fsync_path(src_dir)
-    LOGGER.info('Renamed file: %r -> %r', src, dst)
-
 def _resize_file(path, size):
     ''' Resize an existing file. '''
     with open(path, 'rb+') as file:

+ 58 - 0
qubes/utils.py

@@ -22,12 +22,16 @@
 
 import asyncio
 import hashlib
+import logging
 import random
 import string
 import os
+import os.path
 import re
 import socket
 import subprocess
+import tempfile
+from contextlib import contextmanager, suppress
 
 import pkg_resources
 
@@ -36,6 +40,8 @@ import docutils.core
 import docutils.io
 import qubes.exc
 
+LOGGER = logging.getLogger('qubes.utils')
+
 
 def get_timezone():
     # fc18
@@ -186,6 +192,58 @@ def match_vm_name_with_special(vm, name):
         return name[len('@type:'):] == vm.__class__.__name__
     return name == vm.name
 
+@contextmanager
+def replace_file(dst, *, permissions, close_on_success=True,
+                 logger=LOGGER, log_level=logging.DEBUG):
+    ''' Yield a tempfile whose name starts with dst. If the block does
+        not raise an exception, apply permissions and persist the
+        tempfile to dst (which is allowed to already exist). Otherwise
+        ensure that the tempfile is cleaned up.
+    '''
+    tmp_dir, prefix = os.path.split(dst + '~')
+    tmp = tempfile.NamedTemporaryFile(dir=tmp_dir, prefix=prefix, delete=False)
+    try:
+        yield tmp
+        tmp.flush()
+        os.fchmod(tmp.fileno(), permissions)
+        os.fsync(tmp.fileno())
+        if close_on_success:
+            tmp.close()
+        rename_file(tmp.name, dst, logger=logger, log_level=log_level)
+    except:
+        try:
+            tmp.close()
+        finally:
+            remove_file(tmp.name, logger=logger, log_level=log_level)
+        raise
+
+def rename_file(src, dst, *, logger=LOGGER, log_level=logging.DEBUG):
+    ''' Durably rename src to dst. '''
+    os.rename(src, dst)
+    dst_dir = os.path.dirname(dst)
+    src_dir = os.path.dirname(src)
+    fsync_path(dst_dir)
+    if src_dir != dst_dir:
+        fsync_path(src_dir)
+    logger.log(log_level, 'Renamed file: %r -> %r', src, dst)
+
+def remove_file(path, *, logger=LOGGER, log_level=logging.DEBUG):
+    ''' Durably remove the file at path, if it exists. Return whether
+        we removed it. '''
+    with suppress(FileNotFoundError):
+        os.remove(path)
+        fsync_path(os.path.dirname(path))
+        logger.log(log_level, 'Removed file: %r', path)
+        return True
+    return False
+
+def fsync_path(path):
+    fd = os.open(path, os.O_RDONLY)  # works for a file or a directory
+    try:
+        os.fsync(fd)
+    finally:
+        os.close(fd)
+
 @asyncio.coroutine
 def coro_maybe(value):
     if asyncio.iscoroutine(value):