mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-05-10 14:36:07 +00:00
When cached_hook_patches contain weights for hooks, only use hook_backup for unused keys (#7067)
This commit is contained in:
parent
2bc4b5968f
commit
528d1b3563
@ -1089,7 +1089,6 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
self.unpatch_hooks()
|
|
||||||
if hooks is not None:
|
if hooks is not None:
|
||||||
model_sd_keys = list(self.model_state_dict().keys())
|
model_sd_keys = list(self.model_state_dict().keys())
|
||||||
memory_counter = None
|
memory_counter = None
|
||||||
@ -1100,12 +1099,16 @@ class ModelPatcher:
|
|||||||
# if have cached weights for hooks, use it
|
# if have cached weights for hooks, use it
|
||||||
cached_weights = self.cached_hook_patches.get(hooks, None)
|
cached_weights = self.cached_hook_patches.get(hooks, None)
|
||||||
if cached_weights is not None:
|
if cached_weights is not None:
|
||||||
|
model_sd_keys_set = set(model_sd_keys)
|
||||||
for key in cached_weights:
|
for key in cached_weights:
|
||||||
if key not in model_sd_keys:
|
if key not in model_sd_keys:
|
||||||
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
||||||
continue
|
continue
|
||||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||||
|
model_sd_keys_set.remove(key)
|
||||||
|
self.unpatch_hooks(model_sd_keys_set)
|
||||||
else:
|
else:
|
||||||
|
self.unpatch_hooks()
|
||||||
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
|
||||||
original_weights = None
|
original_weights = None
|
||||||
if len(relevant_patches) > 0:
|
if len(relevant_patches) > 0:
|
||||||
@ -1116,6 +1119,8 @@ class ModelPatcher:
|
|||||||
continue
|
continue
|
||||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||||
memory_counter=memory_counter)
|
memory_counter=memory_counter)
|
||||||
|
else:
|
||||||
|
self.unpatch_hooks()
|
||||||
self.current_hooks = hooks
|
self.current_hooks = hooks
|
||||||
|
|
||||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||||
@ -1172,12 +1177,18 @@ class ModelPatcher:
|
|||||||
del out_weight
|
del out_weight
|
||||||
del weight
|
del weight
|
||||||
|
|
||||||
def unpatch_hooks(self) -> None:
|
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
if len(self.hook_backup) == 0:
|
if len(self.hook_backup) == 0:
|
||||||
self.current_hooks = None
|
self.current_hooks = None
|
||||||
return
|
return
|
||||||
keys = list(self.hook_backup.keys())
|
keys = list(self.hook_backup.keys())
|
||||||
|
if whitelist_keys_set:
|
||||||
|
for k in keys:
|
||||||
|
if k in whitelist_keys_set:
|
||||||
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
self.hook_backup.pop(k)
|
||||||
|
else:
|
||||||
for k in keys:
|
for k in keys:
|
||||||
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user