Code cleanup.

This commit is contained in:
comfyanonymous 2024-08-19 12:48:59 -04:00
parent b29b3b86c5
commit 20ace7c853

View File

@ -644,40 +644,46 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0): def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0 memory_freed = 0
patch_counter = 0 patch_counter = 0
unload_list = []
for n, m in list(self.model.named_modules())[::-1]: for n, m in self.model.named_modules():
if memory_to_free < memory_freed:
break
shift_lowvram = False shift_lowvram = False
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m) module_mem = comfy.model_management.module_size(m)
weight_key = "{}.weight".format(n) unload_list.append((module_mem, n, m))
bias_key = "{}.bias".format(n)
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to: if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]: for key in [weight_key, bias_key]:
bk = self.backup.get(key, None) bk = self.backup.get(key, None)
if bk is not None: if bk is not None:
if bk.inplace_update: if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight) comfy.utils.copy_to_param(self.model, key, bk.weight)
else: else:
comfy.utils.set_attr_param(self.model, key, bk.weight) comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key) self.backup.pop(key)
m.to(device_to) m.to(device_to)
if weight_key in self.patches: if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self) m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1 patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self) m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1 patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
memory_freed += module_mem memory_freed += module_mem
logging.debug("freed {}".format(n)) logging.debug("freed {}".format(n))
self.model.model_lowvram = True self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter