From c18a203a8abdd0fce24743a838fe0d0400d8ff09 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 20 Mar 2024 01:29:26 -0400
Subject: [PATCH] Don't unload model weights for non weight patches.

---
 comfy/model_management.py | 44 ++++++++++++++++++++++------
 comfy/model_patcher.py    | 60 ++++++++++++++++++++++++++-------------
 2 files changed, 76 insertions(+), 28 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 66fa918b..74958908 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -273,6 +273,7 @@ class LoadedModel:
     def __init__(self, model):
         self.model = model
         self.device = model.load_device
+        self.weights_loaded = False
 
     def model_memory(self):
         return self.model.model_size()
@@ -289,11 +290,13 @@ class LoadedModel:
         self.model.model_patches_to(self.device)
         self.model.model_patches_to(self.model.model_dtype())
 
+        load_weights = not self.weights_loaded
+
         try:
-            if lowvram_model_memory > 0:
+            if lowvram_model_memory > 0 and load_weights:
                 self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
             else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to)
+                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
         except Exception as e:
             self.model.unpatch_model(self.model.offload_device)
             self.model_unload()
@@ -302,11 +305,13 @@ class LoadedModel:
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
 
+        self.weights_loaded = True
         return self.real_model
 
-    def model_unload(self):
-        self.model.unpatch_model(self.model.offload_device)
+    def model_unload(self, unpatch_weights=True):
+        self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
+        self.weights_loaded = self.weights_loaded and not unpatch_weights
 
     def __eq__(self, other):
         return self.model is other.model
@@ -314,15 +319,35 @@ class LoadedModel:
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
 
-def unload_model_clones(model):
+def unload_model_clones(loaded_model, unload_weights_only=True):
+    model = loaded_model.model
+
     to_unload = []
     for i in range(len(current_loaded_models)):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
 
+    if len(to_unload) == 0:
+        return
+
+    same_weights = 0
     for i in to_unload:
-        logging.debug("unload clone {}".format(i))
-        current_loaded_models.pop(i).model_unload()
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if unload_weights_only and unload_weight == False:
+        return
+
+    for i in to_unload:
+        logging.debug("unload clone {} {}".format(i, unload_weight))
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    loaded_model.weights_loaded = not unload_weight
 
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = False
@@ -377,13 +402,16 @@ def load_models_gpu(models, memory_required=0):
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
+        unload_model_clones(loaded_model, unload_weights_only=True) #unload clones where the weights are different
         total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
 
+    for loaded_model in models_to_load:
+        unload_model_clones(loaded_model, unload_weights_only=False) #unload the rest of the clones where the weights can stay loaded
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 475fa812..aa78302d 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -2,6 +2,7 @@ import torch
 import copy
 import inspect
 import logging
+import uuid
 
 import comfy.utils
 import comfy.model_management
@@ -25,6 +26,7 @@ class ModelPatcher:
         self.weight_inplace_update = weight_inplace_update
         self.model_lowvram = False
+        self.patches_uuid = uuid.uuid4()
 
     def model_size(self):
         if self.size > 0:
@@ -39,10 +41,13 @@
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
+        n.patches_uuid = self.patches_uuid
 
         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
+        n.backup = self.backup
+        n.object_patches_backup = self.object_patches_backup
         return n
 
     def is_clone(self, other):
@@ -50,6 +55,19 @@
             return True
         return False
 
+    def clone_has_same_weights(self, clone):
+        if not self.is_clone(clone):
+            return False
+
+        if len(self.patches) == 0 and len(clone.patches) == 0:
+            return True
+
+        if self.patches_uuid == clone.patches_uuid:
+            if len(self.patches) != len(clone.patches):
+                logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
+            else:
+                return True
+
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)
 
@@ -154,6 +172,7 @@
             current_patches.append((strength_patch, patches[k], strength_model))
             self.patches[k] = current_patches
 
+        self.patches_uuid = uuid.uuid4()
         return list(p)
 
     def get_key_patches(self, filter_prefix=None):
@@ -387,31 +406,32 @@
 
         return weight
 
-    def unpatch_model(self, device_to=None):
-        if self.model_lowvram:
-            for m in self.model.modules():
-                if hasattr(m, "prev_comfy_cast_weights"):
-                    m.comfy_cast_weights = m.prev_comfy_cast_weights
-                    del m.prev_comfy_cast_weights
-                m.weight_function = None
-                m.bias_function = None
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        if unpatch_weights:
+            if self.model_lowvram:
+                for m in self.model.modules():
+                    if hasattr(m, "prev_comfy_cast_weights"):
+                        m.comfy_cast_weights = m.prev_comfy_cast_weights
+                        del m.prev_comfy_cast_weights
+                    m.weight_function = None
+                    m.bias_function = None
 
-            self.model_lowvram = False
+                self.model_lowvram = False
 
-        keys = list(self.backup.keys())
+            keys = list(self.backup.keys())
 
-        if self.weight_inplace_update:
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.backup[k])
-        else:
-            for k in keys:
-                comfy.utils.set_attr_param(self.model, k, self.backup[k])
+            if self.weight_inplace_update:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
+            else:
+                for k in keys:
+                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
 
-        self.backup = {}
+            self.backup.clear()
 
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+            if device_to is not None:
+                self.model.to(device_to)
+                self.current_device = device_to
 
         keys = list(self.object_patches_backup.keys())
         for k in keys:
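
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the uuid-based check this commit adds in ModelPatcher.clone_has_same_weights, and of why add_patches regenerates patches_uuid. TinyPatcher is a hypothetical stand-in that only mirrors the fields relevant here; it is not ComfyUI code, and its clone_has_same_weights is a simplified version of the method in the diff.

import uuid

class TinyPatcher:
    # Hypothetical stand-in for ModelPatcher: only the bookkeeping needed to
    # show how clones decide whether their weight patches are identical.
    def __init__(self, model):
        self.model = model
        self.patches = {}
        self.patches_uuid = uuid.uuid4()

    def clone(self):
        n = TinyPatcher(self.model)
        n.patches = {k: v[:] for k, v in self.patches.items()}
        n.patches_uuid = self.patches_uuid  # clones start out sharing the uuid
        return n

    def add_patches(self, patches):
        for k, v in patches.items():
            self.patches.setdefault(k, []).append(v)
        self.patches_uuid = uuid.uuid4()  # new weight patches invalidate sharing

    def is_clone(self, other):
        return self.model is other.model

    def clone_has_same_weights(self, clone):
        if not self.is_clone(clone):
            return False
        if len(self.patches) == 0 and len(clone.patches) == 0:
            return True
        return self.patches_uuid == clone.patches_uuid and len(self.patches) == len(clone.patches)

base = TinyPatcher(model=object())
options_only = base.clone()                        # clone that changes no weights
print(base.clone_has_same_weights(options_only))   # True  -> weights can stay loaded
lora = base.clone()
lora.add_patches({"some.weight": ("lora_up", "lora_down")})
print(base.clone_has_same_weights(lora))           # False -> weights must be unpatched

This mirrors the two-pass behaviour in load_models_gpu: the first unload_model_clones pass (unload_weights_only=True) only unloads clones whose patches differ, while the second pass drops the remaining clones without unpatching, so a clone with identical weight patches keeps its weights on the device.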