Browse Source

Merge remote-tracking branch 'origin/pr/198'

* origin/pr/198:
  Fixed new firewall rule window
Marek Marczykowski-Górecki 4 năm trước cách đây
mục cha
commit
bdf0951c52
1 tập tin đã thay đổi với 67 bổ sung57 xóa
  1. 67 57
      qubesmanager/firewall.py

+ 67 - 57
qubesmanager/firewall.py

@@ -90,6 +90,66 @@ class NewFwRuleDlg(QtWidgets.QDialog, ui_newfwruledlg.Ui_NewFwRuleDlg):
         self.populate_combos()
         self.serviceComboBox.setInsertPolicy(QtWidgets.QComboBox.InsertAtTop)
 
+        self.model = None
+
+    def try_to_create_rule(self):
+        # return True if successful, False otherwise
+        address = str(self.addressComboBox.currentText())
+        service = str(self.serviceComboBox.currentText())
+
+        rule = qubesadmin.firewall.Rule(None, action='accept')
+
+        if address is not None and address != "*":
+            try:
+                rule.dsthost = address
+            except ValueError:
+                QtWidgets.QMessageBox.warning(
+                    self, self.tr("Invalid address"),
+                    self.tr("Address '{0}' is invalid.").format(address))
+                return False
+
+        if self.tcp_radio.isChecked():
+            rule.proto = 'tcp'
+        elif self.udp_radio.isChecked():
+            rule.proto = 'udp'
+
+        if self.model.port_range_pattern.fullmatch(service):
+            try:
+                rule.dstports = service
+            except ValueError:
+                QtWidgets.QMessageBox.warning(
+                    self,
+                    self.tr("Invalid port or service"),
+                    self.tr("Port number or service '{0}' is "
+                            "invalid.").format(service))
+                return False
+        elif service:
+            if self.model.service_port_pattern.fullmatch(service):
+                parsed_service = self.model.service_port_pattern.match(
+                    service).groups()[0]
+            else:
+                parsed_service = service
+
+            try:
+                rule.dstports = parsed_service
+            except (TypeError, ValueError):
+                if self.model.get_service_port(parsed_service) is not None:
+                    rule.dstports = self.model.get_service_port(parsed_service)
+                else:
+                    QtWidgets.QMessageBox.warning(
+                        self,
+                        self.tr("Invalid port or service"),
+                        self.tr(
+                            "Port number or service '{0}' is "
+                            "invalid.".format(parsed_service)))
+                    return False
+
+        if self.model.current_row is not None:
+            self.model.set_child(self.model.current_row, rule)
+        else:
+            self.model.append_child(rule)
+        return True
+
     def accept(self):
         if self.tcp_radio.isChecked() or self.udp_radio.isChecked():
             if not self.serviceComboBox.currentText():
@@ -99,7 +159,8 @@ class NewFwRuleDlg(QtWidgets.QDialog, ui_newfwruledlg.Ui_NewFwRuleDlg):
                     self.tr("You need to fill service "
                             "name/port for TCP/UDP rule"))
                 return
-        super().accept()
+        if self.try_to_create_rule():
+            super().accept()
 
     def populate_combos(self):
         example_addresses = [
@@ -145,6 +206,8 @@ class QubesFirewallRulesModel(QtCore.QAbstractItemModel):
     def __init__(self, parent=None):
         QtCore.QAbstractItemModel.__init__(self, parent)
 
+        self.current_row = None
+
         self.__column_names = {0: "Address", 1: "Port/Service", 2: "Protocol", }
         self.__services = list()
 
@@ -366,62 +429,9 @@ class QubesFirewallRulesModel(QtCore.QAbstractItemModel):
             dialog.any_radio.setChecked(True)
 
     def run_rule_dialog(self, dialog, row=None):
-        if dialog.exec_():
-
-            address = str(dialog.addressComboBox.currentText())
-            service = str(dialog.serviceComboBox.currentText())
-
-            rule = qubesadmin.firewall.Rule(None, action='accept')
-
-            if address is not None and address != "*":
-                try:
-                    rule.dsthost = address
-                except ValueError:
-                    QtWidgets.QMessageBox.warning(
-                        dialog, self.tr("Invalid address"),
-                        self.tr("Address '{0}' is invalid.").format(address))
-                    return
-
-            if dialog.tcp_radio.isChecked():
-                rule.proto = 'tcp'
-            elif dialog.udp_radio.isChecked():
-                rule.proto = 'udp'
-
-            if self.port_range_pattern.fullmatch(service):
-                try:
-                    rule.dstports = service
-                except ValueError:
-                    QtWidgets.QMessageBox.warning(
-                        dialog,
-                        self.tr("Invalid port or service"),
-                        self.tr("Port number or service '{0}' is "
-                                "invalid.").format(service))
-                    return
-            elif service:
-                if self.service_port_pattern.fullmatch(service):
-                    parsed_service = self.service_port_pattern.match(
-                        service).groups()[0]
-                else:
-                    parsed_service = service
-
-                try:
-                    rule.dstports = parsed_service
-                except (TypeError, ValueError):
-                    if self.get_service_port(parsed_service) is not None:
-                        rule.dstports = self.get_service_port(parsed_service)
-                    else:
-                        QtWidgets.QMessageBox.warning(
-                            dialog,
-                            self.tr("Invalid port or service"),
-                            self.tr(
-                                "Port number or service '{0}' is "
-                                "invalid.".format(parsed_service)))
-                        return
-
-            if row is not None:
-                self.set_child(row, rule)
-            else:
-                self.append_child(rule)
+        self.current_row = row
+        dialog.model = self
+        dialog.exec()
 
     def index(self, row, column, parent=QtCore.QModelIndex()):
         if not self.hasIndex(row, column, parent):