From 79d5ceae6efe458302899a28496b0daf59890109 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:39:34 -0500 Subject: [PATCH] Improved memory management. (#5450) * Less fragile memory management. * Fix issue. * Remove useless function. * Prevent and detect some types of memory leaks. * Run garbage collector when switching workflow if needed. * Fix issue. --- comfy/model_management.py | 200 ++++++++++++++++---------------------- comfy/model_patcher.py | 36 ++++++- execution.py | 2 +- main.py | 1 - 4 files changed, 119 insertions(+), 120 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a793cab3..2a500956 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -23,6 +23,8 @@ from comfy.cli_args import args import torch import sys import platform +import weakref +import gc class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -287,11 +289,27 @@ def module_size(module): class LoadedModel: def __init__(self, model): - self.model = model + self._set_model(model) self.device = model.load_device - self.weights_loaded = False self.real_model = None self.currently_used = True + self.model_finalizer = None + self._patcher_finalizer = None + + def _set_model(self, model): + self._model = weakref.ref(model) + if model.parent is not None: + self._parent_model = weakref.ref(model.parent) + self._patcher_finalizer = weakref.finalize(model, self._switch_parent) + + def _switch_parent(self): + model = self._parent_model() + if model is not None: + self._set_model(model) + + @property + def model(self): + return self._model() def model_memory(self): return self.model.model_size() @@ -306,32 +324,23 @@ class LoadedModel: return self.model_memory() def model_load(self, lowvram_model_memory=0, force_patch_weights=False): - patch_model_to = self.device - self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) - load_weights = not self.weights_loaded + # if self.model.loaded_size() > 0: + use_more_vram = lowvram_model_memory + if use_more_vram == 0: + use_more_vram = 1e32 + self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model - if self.model.loaded_size() > 0: - use_more_vram = lowvram_model_memory - if use_more_vram == 0: - use_more_vram = 1e32 - self.model_use_more_vram(use_more_vram) - else: - try: - self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights) - except Exception as e: - self.model.unpatch_model(self.model.offload_device) - self.model_unload() - raise e - - if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None: + if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: with torch.no_grad(): - self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) + real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) - self.weights_loaded = True - return self.real_model + self.real_model = weakref.ref(real_model) + self.model_finalizer = weakref.finalize(real_model, cleanup_models) + return real_model def should_reload_model(self, force_patch_weights=False): if force_patch_weights and self.model.lowvram_patch_counter() > 0: @@ -344,18 +353,23 @@ 
class LoadedModel: freed = self.model.partially_unload(self.model.offload_device, memory_to_free) if freed >= memory_to_free: return False - self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) - self.model.model_patches_to(self.model.offload_device) - self.weights_loaded = self.weights_loaded and not unpatch_weights + self.model.detach(unpatch_weights) + self.model_finalizer.detach() + self.model_finalizer = None self.real_model = None return True - def model_use_more_vram(self, extra_memory): - return self.model.partially_load(self.device, extra_memory) + def model_use_more_vram(self, extra_memory, force_patch_weights=False): + return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) def __eq__(self, other): return self.model is other.model + def __del__(self): + if self._patcher_finalizer is not None: + self._patcher_finalizer.detach() + + def use_more_memory(extra_memory, loaded_models, device): for m in loaded_models: if m.device == device: @@ -386,38 +400,8 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def unload_model_clones(model, unload_weights_only=True, force_unload=True): - to_unload = [] - for i in range(len(current_loaded_models)): - if model.is_clone(current_loaded_models[i].model): - to_unload = [i] + to_unload - - if len(to_unload) == 0: - return True - - same_weights = 0 - for i in to_unload: - if model.clone_has_same_weights(current_loaded_models[i].model): - same_weights += 1 - - if same_weights == len(to_unload): - unload_weight = False - else: - unload_weight = True - - if not force_unload: - if unload_weights_only and unload_weight == False: - return None - else: - unload_weight = True - - for i in to_unload: - logging.debug("unload clone {} {}".format(i, unload_weight)) - current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) - - return unload_weight - def free_memory(memory_required, device, keep_loaded=[]): + cleanup_models_gc() unloaded_model = [] can_unload = [] unloaded_models = [] @@ -454,6 +438,7 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + cleanup_models_gc() global vram_state inference_memory = minimum_inference_memory() @@ -466,11 +451,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models = set(models) models_to_load = [] - models_already_loaded = [] + for x in models: loaded_model = LoadedModel(x) - loaded = None - try: loaded_model_index = current_loaded_models.index(loaded_model) except: @@ -478,51 +461,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model_index is not None: loaded = current_loaded_models[loaded_model_index] - if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic - current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True) - loaded = None - else: - loaded.currently_used = True - models_already_loaded.append(loaded) - - if loaded is None: + loaded.currently_used = True + models_to_load.append(loaded) + else: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) - if len(models_to_load) == 0: - devs = set(map(lambda a: a.device, models_already_loaded)) - for d in devs: - 
if d != torch.device("cpu"): - free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded) - free_mem = get_free_memory(d) - if free_mem < minimum_memory_required: - logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed. - models_to_load = free_memory(minimum_memory_required, d) - logging.info("{} models unloaded.".format(len(models_to_load))) - else: - use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d) - if len(models_to_load) == 0: - return - - logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") + for loaded_model in models_to_load: + to_unload = [] + for i in range(len(current_loaded_models)): + if loaded_model.model.is_clone(current_loaded_models[i].model): + to_unload = [i] + to_unload + for i in to_unload: + current_loaded_models.pop(i).model.detach(unpatch_all=False) total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) - for loaded_model in models_already_loaded: - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) - - for loaded_model in models_to_load: - weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded - if weights_unloaded is not None: - loaded_model.weights_loaded = not weights_unloaded - for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + + for device in total_memory_required: + if device != torch.device("cpu"): + free_mem = get_free_memory(device) + if free_mem < minimum_memory_required: + models_l = free_memory(minimum_memory_required, device) + logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: model = loaded_model.model @@ -544,17 +511,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model) - - - devs = set(map(lambda a: a.device, models_already_loaded)) - for d in devs: - if d != torch.device("cpu"): - free_mem = get_free_memory(d) - if free_mem > minimum_memory_required: - use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d) return - def load_model_gpu(model): return load_models_gpu([model]) @@ -568,21 +526,35 @@ def loaded_models(only_currently_used=False): output.append(m.model) return output -def cleanup_models(keep_clone_weights_loaded=False): + +def cleanup_models_gc(): + do_gc = False + for i in range(len(current_loaded_models)): + cur = current_loaded_models[i] + if cur.real_model() is not None and cur.model is None: + logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model 
code.".format(cur.real_model().__class__.__name__)) + do_gc = True + break + + if do_gc: + gc.collect() + soft_empty_cache() + + for i in range(len(current_loaded_models)): + cur = current_loaded_models[i] + if cur.real_model() is not None and cur.model is None: + logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) + + + +def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - #TODO: very fragile function needs improvement - num_refs = sys.getrefcount(current_loaded_models[i].model) - if num_refs <= 2: - if not keep_clone_weights_loaded: - to_delete = [i] + to_delete - #TODO: find a less fragile way to do this. - elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model - to_delete = [i] + to_delete + if current_loaded_models[i].real_model() is None: + to_delete = [i] + to_delete for i in to_delete: x = current_loaded_models.pop(i) - x.model_unload() del x def dtype_size(dtype): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f53f1074..b64c795a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -139,6 +139,7 @@ class ModelPatcher: self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update self.patches_uuid = uuid.uuid4() + self.parent = None if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -149,6 +150,9 @@ class ModelPatcher: if not hasattr(self.model, 'model_lowvram'): self.model.model_lowvram = False + if not hasattr(self.model, 'current_weight_patches_uuid'): + self.model.current_weight_patches_uuid = None + def model_size(self): if self.size > 0: return self.size @@ -172,6 +176,7 @@ class ModelPatcher: n.model_options = copy.deepcopy(self.model_options) n.backup = self.backup n.object_patches_backup = self.object_patches_backup + n.parent = self return n def is_clone(self, other): @@ -464,6 +469,7 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.current_weight_patches_uuid = self.patches_uuid def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): for k in self.object_patches: @@ -498,6 +504,7 @@ class ModelPatcher: else: comfy.utils.set_attr_param(self.model, k, bk.weight) + self.model.current_weight_patches_uuid = None self.backup.clear() if device_to is not None: @@ -568,21 +575,42 @@ class ModelPatcher: self.model.model_loaded_weight_memory -= memory_freed return memory_freed - def partially_load(self, device_to, extra_memory=0): - self.unpatch_model(unpatch_weights=False) + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) + # TODO: force_patch_weights should not unload + reload full model + used = self.model.model_loaded_weight_memory + self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + if unpatch_weights: + extra_memory += (used - self.model.model_loaded_weight_memory) + self.patch_model(load_weights=False) full_load = False - if self.model.model_lowvram == False: + if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: return 0 if self.model.model_loaded_weight_memory + 
extra_memory > self.model_size(): full_load = True current_used = self.model.model_loaded_weight_memory - self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load) + try: + self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) + except Exception as e: + self.detach() + raise e + return self.model.model_loaded_weight_memory - current_used + def detach(self, unpatch_all=True): + self.model_patches_to(self.offload_device) + if unpatch_all: + self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + return self.model + def current_loaded_device(self): return self.model.device def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32): print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) + + def __del__(self): + self.detach(unpatch_all=False) + diff --git a/execution.py b/execution.py index 6c386341..768e35ab 100644 --- a/execution.py +++ b/execution.py @@ -480,7 +480,7 @@ class PromptExecutor: if self.caches.outputs.get(node_id) is not None: cached_nodes.append(node_id) - comfy.model_management.cleanup_models(keep_clone_weights_loaded=True) + comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", { "nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) diff --git a/main.py b/main.py index 05eb31c7..c2c2ff8c 100644 --- a/main.py +++ b/main.py @@ -154,7 +154,6 @@ def prompt_worker(q, server): if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: - comfy.model_management.cleanup_models() gc.collect() comfy.model_management.soft_empty_cache() last_gc_collect = current_time
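
A few sketches of the mechanisms the patch leans on. The new LoadedModel no longer owns its ModelPatcher; it holds it through weakref.ref and registers a weakref.finalize callback so that, when a clone is garbage collected, tracking transparently falls back to the clone's parent patcher. A minimal sketch of that pattern, with Patcher/Loader as illustrative stand-ins rather than the real ComfyUI classes:

import weakref

class Patcher:
    # Stand-in for ModelPatcher: clone() sets parent on the clone.
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent

class Loader:
    # Stand-in for LoadedModel: only weak references to the patcher.
    def __init__(self, patcher):
        self._set_model(patcher)

    def _set_model(self, patcher):
        self._model = weakref.ref(patcher)
        if patcher.parent is not None:
            self._parent_model = weakref.ref(patcher.parent)
            # When the clone dies, switch tracking to its parent.
            weakref.finalize(patcher, self._switch_parent)

    def _switch_parent(self):
        parent = self._parent_model()
        if parent is not None:
            self._set_model(parent)

    @property
    def model(self):
        return self._model()  # None once the patcher has been collected

base = Patcher("base")
clone = Patcher("clone", parent=base)
loader = Loader(clone)
print(loader.model.name)  # clone
del clone                 # on CPython the finalizer runs right away
print(loader.model.name)  # base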
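
model_load() now keeps the inner model only through a weakref and attaches weakref.finalize(real_model, cleanup_models), so the moment nothing else references the model, the stale entry is pruned from current_loaded_models instead of relying on sys.getrefcount heuristics. A sketch of that lifecycle, with Entry, Net and prune_dead_entries as hypothetical stand-ins:

import weakref

current_loaded = []  # stand-in for current_loaded_models

def prune_dead_entries():
    # Mirrors cleanup_models(): drop entries whose model has been collected.
    current_loaded[:] = [e for e in current_loaded if e.real_model() is not None]

class Entry:
    # Stand-in for LoadedModel after model_load(): weak reference only.
    def __init__(self, real_model):
        self.real_model = weakref.ref(real_model)
        self.finalizer = weakref.finalize(real_model, prune_dead_entries)

class Net:  # placeholder for a torch.nn.Module
    pass

net = Net()
current_loaded.append(Entry(net))
print(len(current_loaded))  # 1
del net                     # finalizer fires, dead entry is pruned
print(len(current_loaded))  # 0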
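
cleanup_models_gc() backs the "prevent and detect some types of memory leaks" bullet: an entry whose patcher weakref is already dead while the inner model is still reachable usually means a reference cycle is pinning the weights, so a full gc.collect() is attempted before warning. Roughly, omitting the soft_empty_cache() call:

import gc
import logging

def cleanup_models_gc_sketch(loaded):
    # loaded: entries exposing the same real_model/model pair as LoadedModel.
    leaked = any(e.real_model() is not None and e.model is None for e in loaded)
    if leaked:
        logging.info("potential leak detected, doing a full garbage collect")
        gc.collect()  # break the cycles so the weakref finalizers can fire
    for e in loaded:
        if e.real_model() is not None and e.model is None:
            logging.warning("model %s is still referenced from somewhere",
                            e.real_model().__class__.__name__)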
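
Re-patching decisions now key off a uuid stored on the shared inner model (current_weight_patches_uuid) instead of the per-LoadedModel weights_loaded flag, so any number of clones can agree on whose patches are currently baked into the weights. A reduced illustration with hypothetical SharedModel/PatcherSketch classes:

import uuid

class SharedModel:
    # Stand-in for the inner torch module: weight state lives on the model itself.
    def __init__(self):
        self.current_weight_patches_uuid = None

class PatcherSketch:
    def __init__(self, model):
        self.model = model
        self.patches_uuid = uuid.uuid4()

    def needs_repatch(self, force_patch_weights=False):
        # Same condition partially_load() uses to decide whether to unpatch first.
        applied = self.model.current_weight_patches_uuid
        return applied is not None and (applied != self.patches_uuid or force_patch_weights)

    def patch(self):
        self.model.current_weight_patches_uuid = self.patches_uuid

m = SharedModel()
a, b = PatcherSketch(m), PatcherSketch(m)
a.patch()
print(b.needs_repatch())  # True: the model still carries a's patches
print(a.needs_repatch())  # False: a's patches are already applied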