mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Code cleanup.
This commit is contained in:
parent
b29b3b86c5
commit
20ace7c853
@ -644,40 +644,46 @@ class ModelPatcher:
|
|||||||
def partially_unload(self, device_to, memory_to_free=0):
    """Offload weight modules to ``device_to`` until at least
    ``memory_to_free`` bytes have been freed.

    Candidate modules (those exposing ``comfy_cast_weights``) are collected
    with their sizes and unloaded smallest-first. Before a module is moved,
    any backed-up original weights are restored so the offloaded copy is the
    clean, unpatched one; pending patches are then re-applied lazily at cast
    time via ``LowVramPatch`` hooks.

    Args:
        device_to: device the unloaded modules are moved to (typically the
            CPU / offload device).
        memory_to_free: byte budget to free; the unload loop stops as soon
            as ``memory_freed`` exceeds this value.
    """
    memory_freed = 0
    patch_counter = 0

    # Gather every castable module with its size so unloading can proceed
    # smallest-first (cheapest modules evicted before the big ones).
    unload_list = []
    for n, m in self.model.named_modules():
        if hasattr(m, "comfy_cast_weights"):
            module_mem = comfy.model_management.module_size(m)
            unload_list.append((module_mem, n, m))

    unload_list.sort()  # ascending by size; module name breaks ties
    for module_mem, n, m in unload_list:
        if memory_to_free < memory_freed:
            break

        weight_key = "{}.weight".format(n)
        bias_key = "{}.bias".format(n)

        if m.weight is not None and m.weight.device != device_to:
            # Restore unpatched weights from backup before moving the
            # module, so we offload the original (not the patched) data.
            for key in [weight_key, bias_key]:
                bk = self.backup.get(key, None)
                if bk is not None:
                    if bk.inplace_update:
                        comfy.utils.copy_to_param(self.model, key, bk.weight)
                    else:
                        comfy.utils.set_attr_param(self.model, key, bk.weight)
                    self.backup.pop(key)

            m.to(device_to)
            # Install low-VRAM hooks so patches are applied on the fly
            # the next time this module's weights are cast.
            if weight_key in self.patches:
                m.weight_function = LowVramPatch(weight_key, self)
                patch_counter += 1
            if bias_key in self.patches:
                m.bias_function = LowVramPatch(bias_key, self)
                patch_counter += 1

            m.prev_comfy_cast_weights = m.comfy_cast_weights
            m.comfy_cast_weights = True
            memory_freed += module_mem
            logging.debug("freed {}".format(n))

    self.model.model_lowvram = True
    self.model.lowvram_patch_counter += patch_counter
|
Loading…
Reference in New Issue
Block a user