diff --git a/comfy/model_management.py b/comfy/model_management.py
index ec80afea..23226cbe 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import psutil
 import logging
 from enum import Enum
@@ -273,6 +291,9 @@ class LoadedModel:
     def model_memory(self):
         return self.model.model_size()
 
+    def model_offloaded_memory(self):
+        return self.model.model_size() - self.model.loaded_size()
+
     def model_memory_required(self, device):
         if device == self.model.current_loaded_device():
             return 0
@@ -308,15 +329,37 @@ class LoadedModel:
             return True
         return False
 
-    def model_unload(self, unpatch_weights=True):
+    def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        if memory_to_free is not None:
+            if memory_to_free < self.model.loaded_size():
+                self.model.partially_unload(self.model.offload_device, memory_to_free)
+                return False
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
         self.real_model = None
+        return True
+
+    def model_use_more_vram(self, extra_memory):
+        return self.model.partially_load(self.device, extra_memory)
 
     def __eq__(self, other):
         return self.model is other.model
 
+def use_more_memory(extra_memory, loaded_models, device):
+    for m in loaded_models:
+        if m.device == device:
+            extra_memory -= m.model_use_more_vram(extra_memory)
+            if extra_memory <= 0:
+                break
+
+def offloaded_memory(loaded_models, device):
+    offloaded_mem = 0
+    for m in loaded_models:
+        if m.device == device:
+            offloaded_mem += m.model_offloaded_memory()
+    return offloaded_mem
+
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 1.2
 
@@ -363,11 +406,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
 
     for x in sorted(can_unload):
         i = x[-1]
+        memory_to_free = None
         if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
+            free_mem = get_free_memory(device)
+            if free_mem > memory_required:
                 break
-        current_loaded_models[i].model_unload()
-        unloaded_model.append(i)
+            memory_to_free = memory_required - free_mem
+        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+        if current_loaded_models[i].model_unload(memory_to_free):
+            unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
         unloaded_models.append(current_loaded_models.pop(i))
@@ -422,12 +469,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         devs = set(map(lambda a: a.device, models_already_loaded))
         for d in devs:
             if d != torch.device("cpu"):
-                free_memory(extra_mem, d, models_already_loaded)
+                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
                 free_mem = get_free_memory(d)
                 if free_mem < minimum_memory_required:
                     logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
                     models_to_load = free_memory(minimum_memory_required, d)
                     logging.info("{} models unloaded.".format(len(models_to_load)))
+                else:
+                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
         if len(models_to_load) == 0:
             return
 
@@ -467,6 +516,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
+
+
+    devs = set(map(lambda a: a.device, models_already_loaded))
+    for d in devs:
+        if d != torch.device("cpu"):
+            free_mem = get_free_memory(d)
+            if free_mem > minimum_memory_required:
+                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
     return
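
The model_management.py changes above boil down to two ideas: free_memory() now asks a LoadedModel to give back only the shortfall (model_unload() returns False when it merely shrank the model instead of evicting it), and load_models_gpu() hands spare VRAM back through use_more_memory()/offloaded_memory(). The snippet below is a minimal, runnable sketch of that accounting with toy stand-ins; ToyModel and ToyLoadedModel are illustrative names, not ComfyUI classes, and the real partially_unload()/partially_load() are added to ModelPatcher in the model_patcher.py diff that follows.

# Illustrative sketch only -- not part of the patch. ToyModel/ToyLoadedModel are
# made-up stand-ins for ModelPatcher/LoadedModel that mimic the new accounting.

class ToyModel:
    def __init__(self, size):
        self.size = size        # total weight bytes
        self.loaded = size      # bytes currently resident on the GPU

    def partially_unload(self, memory_to_free):
        freed = min(memory_to_free, self.loaded)
        self.loaded -= freed
        return freed

    def partially_load(self, extra_memory):
        extra = min(extra_memory, self.size - self.loaded)
        self.loaded += extra
        return extra

class ToyLoadedModel:
    def __init__(self, model):
        self.model = model

    def model_unload(self, memory_to_free=None):
        if memory_to_free is not None and memory_to_free < self.model.loaded:
            self.model.partially_unload(memory_to_free)
            return False            # still (partially) loaded
        self.model.loaded = 0
        return True                 # fully unloaded

    def model_use_more_vram(self, extra_memory):
        return self.model.partially_load(extra_memory)

def free_memory(memory_required, free_mem, loaded_models):
    # Ask each loaded model for the current shortfall only; a model that can
    # satisfy it partially stays resident, mirroring the new free_memory() above.
    fully_unloaded = []
    for i, lm in enumerate(loaded_models):
        if free_mem > memory_required:
            break
        before = lm.model.loaded
        if lm.model_unload(memory_required - free_mem):
            fully_unloaded.append(i)
        free_mem += before - lm.model.loaded    # whatever was actually freed
    return fully_unloaded

GiB = 1024 ** 3
lm = ToyLoadedModel(ToyModel(6 * GiB))
free_memory(memory_required=2 * GiB, free_mem=0, loaded_models=[lm])
print(lm.model.loaded // GiB)   # 4 -> only the 2 GiB shortfall left the GPU
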
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 1ef49308..4ee3b35e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1,8 +1,27 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 import copy
 import inspect
 import logging
 import uuid
+import collections
 
 import comfy.utils
 import comfy.model_management
@@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
         model_options["disable_cfg1_optimization"] = True
     return model_options
 
+def wipe_lowvram_weight(m):
+    if hasattr(m, "prev_comfy_cast_weights"):
+        m.comfy_cast_weights = m.prev_comfy_cast_weights
+        del m.prev_comfy_cast_weights
+    m.weight_function = None
+    m.bias_function = None
+
+class LowVramPatch:
+    def __init__(self, key, model_patcher):
+        self.key = key
+        self.model_patcher = model_patcher
+    def __call__(self, weight):
+        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
@@ -82,16 +116,29 @@ class ModelPatcher:
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
-        self.model_lowvram = False
-        self.lowvram_patch_counter = 0
         self.patches_uuid = uuid.uuid4()
 
+        if not hasattr(self.model, 'model_loaded_weight_memory'):
+            self.model.model_loaded_weight_memory = 0
+
+        if not hasattr(self.model, 'lowvram_patch_counter'):
+            self.model.lowvram_patch_counter = 0
+
+        if not hasattr(self.model, 'model_lowvram'):
+            self.model.model_lowvram = False
+
     def model_size(self):
         if self.size > 0:
             return self.size
         self.size = comfy.model_management.module_size(self.model)
         return self.size
 
+    def loaded_size(self):
+        return self.model.model_loaded_weight_memory
+
+    def lowvram_patch_counter(self):
+        return self.model.lowvram_patch_counter
+
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
@@ -265,16 +312,16 @@ class ModelPatcher:
             sd.pop(k)
         return sd
 
-    def patch_weight_to_device(self, key, device_to=None):
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         if key not in self.patches:
             return
 
         weight = comfy.utils.get_attr(self.model, key)
 
-        inplace_update = self.weight_inplace_update
+        inplace_update = self.weight_inplace_update or inplace_update
 
         if key not in self.backup:
-            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
 
         if device_to is not None:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
@@ -304,28 +351,24 @@ class ModelPatcher:
         if device_to is not None:
             self.model.to(device_to)
             self.model.device = device_to
+            self.model.model_loaded_weight_memory = self.model_size()
 
         return self.model
 
-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
-        self.patch_model(device_to, patch_weights=False)
-
+    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
         logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
-        class LowVramPatch:
-            def __init__(self, key, model_patcher):
-                self.key = key
-                self.model_patcher = model_patcher
-            def __call__(self, weight):
-                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
-
         mem_counter = 0
         patch_counter = 0
+        lowvram_counter = 0
         for n, m in self.model.named_modules():
             lowvram_weight = False
             if hasattr(m, "comfy_cast_weights"):
                 module_mem = comfy.model_management.module_size(m)
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
+                    lowvram_counter += 1
+                    if m.comfy_cast_weights:
+                        continue
 
             weight_key = "{}.weight".format(n)
             bias_key = "{}.bias".format(n)
@@ -347,16 +390,31 @@ class ModelPatcher:
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
             else:
+                if hasattr(m, "comfy_cast_weights"):
+                    if m.comfy_cast_weights:
+                        wipe_lowvram_weight(m)
+
                 if hasattr(m, "weight"):
-                    self.patch_weight_to_device(weight_key) #TODO: speed this up without causing OOM
+                    mem_counter += comfy.model_management.module_size(m)
+                    if m.weight is not None and m.weight.device == device_to:
+                        continue
+                    self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
                     self.patch_weight_to_device(bias_key)
                     m.to(device_to)
-                    mem_counter += comfy.model_management.module_size(m)
                     logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
 
-        self.model_lowvram = True
-        self.lowvram_patch_counter = patch_counter
+        if lowvram_counter > 0:
+            self.model.model_lowvram = True
+        else:
+            self.model.model_lowvram = False
+        self.model.lowvram_patch_counter += patch_counter
         self.model.device = device_to
+        self.model.model_loaded_weight_memory = mem_counter
+
+
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
+        self.patch_model(device_to, patch_weights=False)
+        self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
         return self.model
 
     def calculate_weight(self, patches, weight, key):
@@ -529,31 +587,28 @@ class ModelPatcher:
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         if unpatch_weights:
-            if self.model_lowvram:
+            if self.model.model_lowvram:
                 for m in self.model.modules():
-                    if hasattr(m, "prev_comfy_cast_weights"):
-                        m.comfy_cast_weights = m.prev_comfy_cast_weights
-                        del m.prev_comfy_cast_weights
-                    m.weight_function = None
-                    m.bias_function = None
+                    wipe_lowvram_weight(m)
 
-                self.model_lowvram = False
-                self.lowvram_patch_counter = 0
+                self.model.model_lowvram = False
+                self.model.lowvram_patch_counter = 0
 
             keys = list(self.backup.keys())
 
-            if self.weight_inplace_update:
-                for k in keys:
-                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
-            else:
-                for k in keys:
-                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
+            for k in keys:
+                bk = self.backup[k]
+                if bk.inplace_update:
+                    comfy.utils.copy_to_param(self.model, k, bk.weight)
+                else:
+                    comfy.utils.set_attr_param(self.model, k, bk.weight)
 
             self.backup.clear()
 
             if device_to is not None:
                 self.model.to(device_to)
                 self.model.device = device_to
+                self.model.model_loaded_weight_memory = 0
 
         keys = list(self.object_patches_backup.keys())
         for k in keys:
@@ -561,5 +616,57 @@ class ModelPatcher:
 
         self.object_patches_backup.clear()
 
+    def partially_unload(self, device_to, memory_to_free=0):
+        memory_freed = 0
+        patch_counter = 0
+
+        for n, m in list(self.model.named_modules())[::-1]:
+            if memory_to_free < memory_freed:
+                break
+
+            shift_lowvram = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+
+
+                if m.weight is not None and m.weight.device != device_to:
+                    for key in [weight_key, bias_key]:
+                        bk = self.backup.get(key, None)
+                        if bk is not None:
+                            if bk.inplace_update:
+                                comfy.utils.copy_to_param(self.model, key, bk.weight)
+                            else:
+                                comfy.utils.set_attr_param(self.model, key, bk.weight)
+                            self.backup.pop(key)
+
+                    m.to(device_to)
+                    if weight_key in self.patches:
+                        m.weight_function = LowVramPatch(weight_key, self)
+                        patch_counter += 1
+                    if bias_key in self.patches:
+                        m.bias_function = LowVramPatch(bias_key, self)
+                        patch_counter += 1
+
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    memory_freed += module_mem
+                    logging.debug("freed {}".format(n))
+
+        self.model.model_lowvram = True
+        self.model.lowvram_patch_counter += patch_counter
+        self.model.model_loaded_weight_memory -= memory_freed
+        return memory_freed
+
+    def partially_load(self, device_to, extra_memory=0):
+        if self.model.model_lowvram == False:
+            return 0
+        if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
+            pass #TODO: Full load
+        current_used = self.model.model_loaded_weight_memory
+        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
+        return self.model.model_loaded_weight_memory - current_used
+
     def current_loaded_device(self):
         return self.model.device
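
Taken together, the new ModelPatcher methods form a two-way valve: partially_unload() walks named_modules() in reverse, restores any backed-up weights, moves modules back to the offload device and installs LowVramPatch functions so their patches still apply at cast time, while partially_load() simply re-runs lowvram_load() with a larger budget, which skips modules whose weights already sit on the target device. A hypothetical caller (illustrative only; rebalance and want_free_bytes are made-up names, not part of ComfyUI) might drive the pair like this:

import comfy.model_management as mm

def rebalance(patcher, want_free_bytes):
    # Keep roughly `want_free_bytes` of VRAM free by shifting the weights of an
    # already lowvram-loaded ModelPatcher between its load and offload devices.
    device = patcher.load_device
    free = mm.get_free_memory(device)
    if free < want_free_bytes:
        # Too tight: push the most recently loaded modules back to the offload
        # device; LowVramPatch keeps applying their patches on the fly.
        patcher.partially_unload(patcher.offload_device, want_free_bytes - free)
    else:
        # Headroom available: resume lowvram_load() with a larger budget; modules
        # already resident on `device` are skipped, so only the surplus moves.
        patcher.partially_load(device, free - want_free_bytes)
    return patcher.loaded_size()    # bytes of weights now resident on the GPU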