diff --git a/comfy/model_management.py b/comfy/model_management.py
index ec80afea..23226cbe 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1,3 +1,21 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
import psutil
import logging
from enum import Enum
@@ -273,6 +291,9 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
+ def model_offloaded_memory(self):
+ return self.model.model_size() - self.model.loaded_size()
+
def model_memory_required(self, device):
if device == self.model.current_loaded_device():
return 0
@@ -308,15 +329,37 @@ class LoadedModel:
return True
return False
- def model_unload(self, unpatch_weights=True):
+ def model_unload(self, memory_to_free=None, unpatch_weights=True):
+ if memory_to_free is not None:
+ if memory_to_free < self.model.loaded_size():
+ self.model.partially_unload(self.model.offload_device, memory_to_free)
+ return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.real_model = None
+ return True
+
+ def model_use_more_vram(self, extra_memory):
+ return self.model.partially_load(self.device, extra_memory)
def __eq__(self, other):
return self.model is other.model
+def use_more_memory(extra_memory, loaded_models, device):
+ for m in loaded_models:
+ if m.device == device:
+ extra_memory -= m.model_use_more_vram(extra_memory)
+ if extra_memory <= 0:
+ break
+
+def offloaded_memory(loaded_models, device):
+ offloaded_mem = 0
+ for m in loaded_models:
+ if m.device == device:
+ offloaded_mem += m.model_offloaded_memory()
+ return offloaded_mem
+
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
@@ -363,11 +406,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
for x in sorted(can_unload):
i = x[-1]
+ memory_to_free = None
if not DISABLE_SMART_MEMORY:
- if get_free_memory(device) > memory_required:
+ free_mem = get_free_memory(device)
+ if free_mem > memory_required:
break
- current_loaded_models[i].model_unload()
- unloaded_model.append(i)
+ memory_to_free = memory_required - free_mem
+ logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+ if current_loaded_models[i].model_unload(memory_to_free, free_mem):
+ unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@@ -422,12 +469,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
- free_memory(extra_mem, d, models_already_loaded)
+ free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load)))
+ else:
+ use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
if len(models_to_load) == 0:
return
@@ -467,6 +516,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
+
+
+ devs = set(map(lambda a: a.device, models_already_loaded))
+ for d in devs:
+ if d != torch.device("cpu"):
+ free_mem = get_free_memory(d)
+ if free_mem > minimum_memory_required:
+ use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
return
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 1ef49308..4ee3b35e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1,8 +1,27 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
import torch
import copy
import inspect
import logging
import uuid
+import collections
import comfy.utils
import comfy.model_management
@@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
model_options["disable_cfg1_optimization"] = True
return model_options
+def wipe_lowvram_weight(m):
+ if hasattr(m, "prev_comfy_cast_weights"):
+ m.comfy_cast_weights = m.prev_comfy_cast_weights
+ del m.prev_comfy_cast_weights
+ m.weight_function = None
+ m.bias_function = None
+
+class LowVramPatch:
+ def __init__(self, key, model_patcher):
+ self.key = key
+ self.model_patcher = model_patcher
+ def __call__(self, weight):
+ return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@@ -82,16 +116,29 @@ class ModelPatcher:
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
- self.model_lowvram = False
- self.lowvram_patch_counter = 0
self.patches_uuid = uuid.uuid4()
+ if not hasattr(self.model, 'model_loaded_weight_memory'):
+ self.model.model_loaded_weight_memory = 0
+
+ if not hasattr(self.model, 'lowvram_patch_counter'):
+ self.model.lowvram_patch_counter = 0
+
+ if not hasattr(self.model, 'model_lowvram'):
+ self.model.model_lowvram = False
+
def model_size(self):
if self.size > 0:
return self.size
self.size = comfy.model_management.module_size(self.model)
return self.size
+ def loaded_size(self):
+ return self.model.model_loaded_weight_memory
+
+ def lowvram_patch_counter(self):
+ return self.model.lowvram_patch_counter
+
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
@@ -265,16 +312,16 @@ class ModelPatcher:
sd.pop(k)
return sd
- def patch_weight_to_device(self, key, device_to=None):
+ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
- inplace_update = self.weight_inplace_update
+ inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup:
- self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+ self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
@@ -304,28 +351,24 @@ class ModelPatcher:
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
+ self.model.model_loaded_weight_memory = self.model_size()
return self.model
- def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
- self.patch_model(device_to, patch_weights=False)
-
+ def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
- class LowVramPatch:
- def __init__(self, key, model_patcher):
- self.key = key
- self.model_patcher = model_patcher
- def __call__(self, weight):
- return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
-
mem_counter = 0
patch_counter = 0
+ lowvram_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
+ lowvram_counter += 1
+ if m.comfy_cast_weights:
+ continue
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
@@ -347,16 +390,31 @@ class ModelPatcher:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
+ if hasattr(m, "comfy_cast_weights"):
+ if m.comfy_cast_weights:
+ wipe_lowvram_weight(m)
+
if hasattr(m, "weight"):
- self.patch_weight_to_device(weight_key) #TODO: speed this up without causing OOM
+ mem_counter += comfy.model_management.module_size(m)
+ if m.weight is not None and m.weight.device == device_to:
+ continue
+ self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
self.patch_weight_to_device(bias_key)
m.to(device_to)
- mem_counter += comfy.model_management.module_size(m)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
- self.model_lowvram = True
- self.lowvram_patch_counter = patch_counter
+ if lowvram_counter > 0:
+ self.model.model_lowvram = True
+ else:
+ self.model.model_lowvram = False
+ self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
+ self.model.model_loaded_weight_memory = mem_counter
+
+
+ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
+ self.patch_model(device_to, patch_weights=False)
+ self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
return self.model
def calculate_weight(self, patches, weight, key):
@@ -529,31 +587,28 @@ class ModelPatcher:
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
- if self.model_lowvram:
+ if self.model.model_lowvram:
for m in self.model.modules():
- if hasattr(m, "prev_comfy_cast_weights"):
- m.comfy_cast_weights = m.prev_comfy_cast_weights
- del m.prev_comfy_cast_weights
- m.weight_function = None
- m.bias_function = None
+ wipe_lowvram_weight(m)
- self.model_lowvram = False
- self.lowvram_patch_counter = 0
+ self.model.model_lowvram = False
+ self.model.lowvram_patch_counter = 0
keys = list(self.backup.keys())
- if self.weight_inplace_update:
- for k in keys:
- comfy.utils.copy_to_param(self.model, k, self.backup[k])
- else:
- for k in keys:
- comfy.utils.set_attr_param(self.model, k, self.backup[k])
+ for k in keys:
+ bk = self.backup[k]
+ if bk.inplace_update:
+ comfy.utils.copy_to_param(self.model, k, bk.weight)
+ else:
+ comfy.utils.set_attr_param(self.model, k, bk.weight)
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
+ self.model.model_loaded_weight_memory = 0
keys = list(self.object_patches_backup.keys())
for k in keys:
@@ -561,5 +616,57 @@ class ModelPatcher:
self.object_patches_backup.clear()
+ def partially_unload(self, device_to, memory_to_free=0):
+ memory_freed = 0
+ patch_counter = 0
+
+ for n, m in list(self.model.named_modules())[::-1]:
+ if memory_to_free < memory_freed:
+ break
+
+ shift_lowvram = False
+ if hasattr(m, "comfy_cast_weights"):
+ module_mem = comfy.model_management.module_size(m)
+ weight_key = "{}.weight".format(n)
+ bias_key = "{}.bias".format(n)
+
+
+ if m.weight is not None and m.weight.device != device_to:
+ for key in [weight_key, bias_key]:
+ bk = self.backup.get(key, None)
+ if bk is not None:
+ if bk.inplace_update:
+ comfy.utils.copy_to_param(self.model, key, bk.weight)
+ else:
+ comfy.utils.set_attr_param(self.model, key, bk.weight)
+ self.backup.pop(key)
+
+ m.to(device_to)
+ if weight_key in self.patches:
+ m.weight_function = LowVramPatch(weight_key, self)
+ patch_counter += 1
+ if bias_key in self.patches:
+ m.bias_function = LowVramPatch(bias_key, self)
+ patch_counter += 1
+
+ m.prev_comfy_cast_weights = m.comfy_cast_weights
+ m.comfy_cast_weights = True
+ memory_freed += module_mem
+ logging.debug("freed {}".format(n))
+
+ self.model.model_lowvram = True
+ self.model.lowvram_patch_counter += patch_counter
+ self.model.model_loaded_weight_memory -= memory_freed
+ return memory_freed
+
+ def partially_load(self, device_to, extra_memory=0):
+ if self.model.model_lowvram == False:
+ return 0
+ if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
+ pass #TODO: Full load
+ current_used = self.model.model_loaded_weight_memory
+ self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
+ return self.model.model_loaded_weight_memory - current_used
+
def current_loaded_device(self):
return self.model.device