When cached_hook_patches contain weights for hooks, only use hook_backup for unused keys (#7067)

This commit is contained in:
Jedrzej Kosinski 2025-03-09 03:26:31 -05:00 committed by GitHub
parent 2bc4b5968f
commit 528d1b3563
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1089,7 +1089,6 @@ class ModelPatcher:
def patch_hooks(self, hooks: comfy.hooks.HookGroup): def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected(): with self.use_ejected():
self.unpatch_hooks()
if hooks is not None: if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys()) model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None memory_counter = None
@ -1100,12 +1099,16 @@ class ModelPatcher:
# if have cached weights for hooks, use it # if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None) cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None: if cached_weights is not None:
model_sd_keys_set = set(model_sd_keys)
for key in cached_weights: for key in cached_weights:
if key not in model_sd_keys: if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}") logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
model_sd_keys_set.remove(key)
self.unpatch_hooks(model_sd_keys_set)
else: else:
self.unpatch_hooks()
relevant_patches = self.get_combined_hook_patches(hooks=hooks) relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None original_weights = None
if len(relevant_patches) > 0: if len(relevant_patches) > 0:
@ -1116,6 +1119,8 @@ class ModelPatcher:
continue continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter) memory_counter=memory_counter)
else:
self.unpatch_hooks()
self.current_hooks = hooks self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@ -1172,12 +1177,18 @@ class ModelPatcher:
del out_weight del out_weight
del weight del weight
def unpatch_hooks(self) -> None: def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
with self.use_ejected(): with self.use_ejected():
if len(self.hook_backup) == 0: if len(self.hook_backup) == 0:
self.current_hooks = None self.current_hooks = None
return return
keys = list(self.hook_backup.keys()) keys = list(self.hook_backup.keys())
if whitelist_keys_set:
for k in keys:
if k in whitelist_keys_set:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
self.hook_backup.pop(k)
else:
for k in keys: for k in keys:
comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))