Sfoglia il codice sorgente

qvm-template: Add confirmation for dangerous operations; verify signatures once instead of twice by returning header after verification.

WillyPillow 3 anni fa
parent
commit
8ee0d639b8
1 ha cambiato i file con 111 aggiunte e 88 eliminazioni
  1. 111 88
      qubesadmin/tools/qvm_template.py

+ 111 - 88
qubesadmin/tools/qvm_template.py

@@ -71,6 +71,8 @@ def parser_gen() -> argparse.ArgumentParser:
         help='Set repository metadata as expired before running the command.')
     parser_main.add_argument('--cachedir', default=CACHE_DIR,
         help='Specify cache directory.')
+    parser_main.add_argument('--yes', action='store_true',
+        help='Assume "yes" to questions.')
     # qvm-template {install,reinstall,downgrade,upgrade}
     parser_install = parser_add_command('install',
         help_str='Install template packages.')
@@ -486,8 +488,9 @@ def verify_rpm(
         path: str,
         nogpgcheck: bool = False,
         transaction_set: typing.Optional[rpm.transaction.TransactionSet] = None
-        ) -> bool:
-    """Verify the digest and signature of a RPM package.
+        ) -> rpm.hdr:
+    """Verify the digest and signature of a RPM package and return the package
+    header.
 
     Note that verifying RPMs this way is prone to TOCTOU. This is okay for
     local files, but may create problems if multiple instances of
@@ -498,42 +501,22 @@ def verify_rpm(
     :param nogpgcheck: Whether to allow invalid GPG signatures
     :param transaction_set: Override RPM ``TransactionSet``. Optional
 
-    :return: Whether the RPM is verified
+    :return: RPM package header. If verification fails, ``None`` is returned.
     """
     if transaction_set is None:
         transaction_set = rpm.TransactionSet()
     with open(path, 'rb') as fd:
         try:
             hdr = transaction_set.hdrFromFdno(fd)
-            if hdr[rpm.RPMTAG_SIGSIZE] is None \
-                    and hdr[rpm.RPMTAG_SIGPGP] is None \
+            if hdr[rpm.RPMTAG_SIGPGP] is None \
                     and hdr[rpm.RPMTAG_SIGGPG] is None:
-                return nogpgcheck
+                return hdr if nogpgcheck else None
         except rpm.error as e:
             if str(e) == 'public key not trusted' \
                     or str(e) == 'public key not available':
-                return nogpgcheck
-            return False
-    return True
-
-def get_package_hdr(
-        path: str,
-        transaction_set: typing.Optional[rpm.transaction.TransactionSet] = None
-        ) -> rpm.hdr:
-    """Return header of a RPM package.
-
-    Note that this function **does not** check the integrity of the package.
-
-    :param path: Location of the RPM package
-    :param transaction_set: Override RPM ``TransactionSet``. Optional
-
-    :return: RPM headers
-    """
-    if transaction_set is None:
-        transaction_set = rpm.TransactionSet()
-    with open(path, 'rb') as fd:
-        hdr = transaction_set.hdrFromFdno(fd)
-        return hdr
+                return hdr if nogpgcheck else None
+            return None
+    return hdr
 
 def extract_rpm(name: str, path: str, target: str) -> bool:
     """Extract a template RPM package.
@@ -724,21 +707,90 @@ def install(
     try:
         transaction_set = rpm.TransactionSet()
 
-        rpm_list = [] # rpmfile, reponame
+        unverified_rpm_list = [] # rpmfile, reponame
+        verified_rpm_list = []
+        def verify(rpmfile, reponame):
+            """Verify package signature and version, remove "unverified"
+            suffix, and parse package header."""
+            if reponame != '@commandline':
+                path = rpmfile + UNVERIFIED_SUFFIX
+            else:
+                path = rpmfile
+
+            package_hdr = verify_rpm(path, args.nogpgcheck, transaction_set)
+            if not package_hdr:
+                parser.error('Package \'%s\' verification failed.' % rpmfile)
+
+            if reponame != '@commandline':
+                os.rename(path, rpmfile)
+
+            package_name = package_hdr[rpm.RPMTAG_NAME]
+            if not package_name.startswith(PACKAGE_NAME_PREFIX):
+                parser.error(
+                    'Illegal package name for package \'%s\'.' % rpmfile)
+            # Remove prefix to get the real template name
+            name = package_name[len(PACKAGE_NAME_PREFIX):]
+
+            # Check if already installed
+            if not override_existing and name in app.domains:
+                print(('Template \'%s\' already installed, skipping...'
+                    ' (You may want to use the'
+                    ' {reinstall,upgrade,downgrade}'
+                    ' operations.)') % name, file=sys.stderr)
+                return
+
+            # Check if version is really what we want
+            if override_existing:
+                vm = get_managed_template_vm(app, name)
+                pkg_evr = (
+                    str(package_hdr[rpm.RPMTAG_EPOCHNUM]),
+                    package_hdr[rpm.RPMTAG_VERSION],
+                    package_hdr[rpm.RPMTAG_RELEASE])
+                vm_evr = query_local_evr(vm)
+                cmp_res = rpm.labelCompare(pkg_evr, vm_evr)
+                if version_selector == VersionSelector.REINSTALL \
+                        and cmp_res != 0:
+                    parser.error(
+                        'Same version of template \'%s\' not found.' \
+                        % name)
+                elif version_selector == VersionSelector.LATEST_LOWER \
+                        and cmp_res != -1:
+                    print(("Template '%s' of lower version"
+                        " already installed, skipping..." % name),
+                        file=sys.stderr)
+                    return
+                elif version_selector == VersionSelector.LATEST_HIGHER \
+                        and cmp_res != 1:
+                    print(("Template '%s' of higher version"
+                        " already installed, skipping..." % name),
+                        file=sys.stderr)
+                    return
+
+            verified_rpm_list.append((rpmfile, reponame, name, package_hdr))
+
+        # Process local templates
         for template in args.templates:
             if template.endswith('.rpm'):
                 if not os.path.exists(template):
                     parser.error('RPM file \'%s\' not found.' % template)
-                rpm_list.append((template, '@commandline'))
+                unverified_rpm_list.append((template, '@commandline'))
+
+        # First verify local RPMs and extract header
+        for rpmfile, reponame in unverified_rpm_list:
+            verify(rpmfile, reponame)
+        unverified_rpm_list = []
 
         os.makedirs(args.cachedir, exist_ok=True)
 
+        # Get list of templates to download
         dl_list = get_dl_list(args, app, version_selector=version_selector)
         dl_list_copy = dl_list.copy()
-        # Verify that the templates are not yet installed
         for name, entry in dl_list.items():
             # Should be ensured by checks in repoquery
             assert entry.reponame != '@commandline'
+            # Verify that the templates to be downloaded are not yet installed
+            # Note that we *still* have to do this again in verify() for
+            # already-downloaded templates
             if not override_existing and name in app.domains:
                 print(('Template \'%s\' already installed, skipping...'
                     ' (You may want to use the'
@@ -746,75 +798,46 @@ def install(
                     ' operations.)') % name, file=sys.stderr)
                 del dl_list_copy[name]
             else:
+                # XXX: Perhaps this is better returned by download()
                 version_str = build_version_str(entry.evr)
                 target_file = \
                     '%s%s-%s.rpm' % (PACKAGE_NAME_PREFIX, name, version_str)
-                rpm_list.append(
+                unverified_rpm_list.append(
                     (os.path.join(args.cachedir, target_file), entry.reponame))
         dl_list = dl_list_copy
 
+        # Ask the user for confirmation before we actually download stuff
+        if override_existing and not args.yes:
+            override_tpls = []
+            # Local templates, already verified
+            for _, _, name, _ in verified_rpm_list:
+                override_tpls.append(name)
+            # Templates not yet downloaded
+            for name in dl_list:
+                override_tpls.append(name)
+
+            print('This will override changes made in the following VMs:',
+                file=sys.stderr)
+            for tpl in override_tpls:
+                print('  %s' % tpl, file=sys.stderr)
+            confirm = ''
+            while confirm != 'y':
+                confirm = input('Are you sure? [y/N] ').lower()
+                if confirm == 'n':
+                    sys.exit(1)
+
         download(args, app, path_override=args.cachedir,
             dl_list=dl_list, suffix=UNVERIFIED_SUFFIX,
             version_selector=version_selector)
 
-        # Verify package and remove unverified suffix
-        for rpmfile, reponame in rpm_list:
-            if reponame != '@commandline':
-                path = rpmfile + UNVERIFIED_SUFFIX
-            else:
-                path = rpmfile
-            if not verify_rpm(path, args.nogpgcheck, transaction_set):
-                parser.error('Package \'%s\' verification failed.' % rpmfile)
-            if reponame != '@commandline':
-                os.rename(path, rpmfile)
+        # Verify downloaded templates
+        for rpmfile, reponame in unverified_rpm_list:
+            verify(rpmfile, reponame)
+        unverified_rpm_list = []
 
         # Unpack and install
-        for rpmfile, reponame in rpm_list:
+        for rpmfile, reponame, name, package_hdr in verified_rpm_list:
             with tempfile.TemporaryDirectory(dir=TEMP_DIR) as target:
-                package_hdr = get_package_hdr(rpmfile)
-                package_name = package_hdr[rpm.RPMTAG_NAME]
-                if not package_name.startswith(PACKAGE_NAME_PREFIX):
-                    parser.error(
-                        'Illegal package name for package \'%s\'.' % rpmfile)
-                # Remove prefix to get the real template name
-                name = package_name[len(PACKAGE_NAME_PREFIX):]
-
-                # Another check for already-downloaded RPMs
-                if not override_existing and name in app.domains:
-                    print(('Template \'%s\' already installed, skipping...'
-                        ' (You may want to use the'
-                        ' {reinstall,upgrade,downgrade}'
-                        ' operations.)') % name, file=sys.stderr)
-                    continue
-
-                # Check if local versus candidate version is in line with the
-                # operation
-                if override_existing:
-                    vm = get_managed_template_vm(app, name)
-                    pkg_evr = (
-                        str(package_hdr[rpm.RPMTAG_EPOCHNUM]),
-                        package_hdr[rpm.RPMTAG_VERSION],
-                        package_hdr[rpm.RPMTAG_RELEASE])
-                    vm_evr = query_local_evr(vm)
-                    cmp_res = rpm.labelCompare(pkg_evr, vm_evr)
-                    if version_selector == VersionSelector.REINSTALL \
-                            and cmp_res != 0:
-                        parser.error(
-                            'Same version of template \'%s\' not found.' \
-                            % name)
-                    elif version_selector == VersionSelector.LATEST_LOWER \
-                            and cmp_res != -1:
-                        print(("Template '%s' of lower version"
-                            " already installed, skipping..." % name),
-                            file=sys.stderr)
-                        continue
-                    elif version_selector == VersionSelector.LATEST_HIGHER \
-                            and cmp_res != 1:
-                        print(("Template '%s' of higher version"
-                            " already installed, skipping..." % name),
-                            file=sys.stderr)
-                        continue
-
                 print('Installing template \'%s\'...' % name, file=sys.stderr)
                 extract_rpm(name, rpmfile, target)
                 cmdline = [
@@ -824,7 +847,7 @@ def install(
                 ]
                 if args.allow_pv:
                     cmdline.append('--allow-pv')
-                if args.pool:
+                if not override_existing and args.pool:
                     cmdline += ['--pool', args.pool]
                 subprocess.check_call(cmdline + [
                     'post-install',