Store VM collection connected to NetVM

This commit is contained in:
Marek Marczykowski 2011-04-04 19:08:40 +02:00
parent a6d079594b
commit 2aec07dd60

View File

@ -203,6 +203,8 @@ class QubesVm(object):
self.uses_default_netvm = uses_default_netvm self.uses_default_netvm = uses_default_netvm
self.netvm_vm = netvm_vm self.netvm_vm = netvm_vm
if netvm_vm is not None:
netvm_vm.connected_vms[qid] = self
# We use it in remove from disk to avoid removing rpm files (for templates) # We use it in remove from disk to avoid removing rpm files (for templates)
self.installed_by_rpm = installed_by_rpm self.installed_by_rpm = installed_by_rpm
@ -1242,6 +1244,7 @@ class QubesNetVm(QubesVm):
if "vcpus" not in kwargs or kwargs["vcpus"] is None: if "vcpus" not in kwargs or kwargs["vcpus"] is None:
kwargs["vcpus"] = default_servicevm_vcpus kwargs["vcpus"] = default_servicevm_vcpus
super(QubesNetVm, self).__init__(**kwargs) super(QubesNetVm, self).__init__(**kwargs)
self.connected_vms = QubesVmCollection()
@property @property
def type(self): def type(self):
@ -1383,17 +1386,8 @@ class QubesProxyVm(QubesNetVm):
# Allow dom0 networking # Allow dom0 networking
iptables += "-A FORWARD -i vif0.0 -j ACCEPT\n" iptables += "-A FORWARD -i vif0.0 -j ACCEPT\n"
qvm_collection = QubesVmCollection() vms = [vm for vm in self.connected_vms.values() if vm.has_firewall()]
qvm_collection.lock_db_for_reading()
qvm_collection.load()
qvm_collection.unlock_db()
vms = [vm for vm in qvm_collection.values() if vm.has_firewall()]
for vm in vms: for vm in vms:
# Process only VMs connected to this ProxyVM
if not vm.netvm_vm or vm.netvm_vm.qid != self.qid:
continue
conf = vm.get_firewall_conf() conf = vm.get_firewall_conf()
xid = vm.get_xid() xid = vm.get_xid()
@ -1795,8 +1789,8 @@ class QubesVmCollection(dict):
while len(new_vms) > 0: while len(new_vms) > 0:
cur_vm = new_vms.pop() cur_vm = new_vms.pop()
for vm in self.values(): for vm in cur_vm.connected_vms.values():
if vm.netvm_vm and vm.netvm_vm.qid == cur_vm and vm.qid not in dependend_vms_qid: if vm.qid not in dependend_vms_qid:
dependend_vms_qid.append(vm.qid) dependend_vms_qid.append(vm.qid)
if vm.is_netvm(): if vm.is_netvm():
new_vms.append(vm.qid) new_vms.append(vm.qid)
@ -1964,6 +1958,8 @@ class QubesVmCollection(dict):
netvm_vm = self[netvm_qid] netvm_vm = self[netvm_qid]
vm.netvm_vm = netvm_vm vm.netvm_vm = netvm_vm
if netvm_vm:
netvm_vm.connected_vms[vm.qid] = vm
def load(self): def load(self):
self.clear() self.clear()