Partial model shift support.

This commit is contained in:
comfyanonymous 2024-08-08 03:27:37 -04:00
parent 8115d8cce9
commit 08f92d55e9
2 changed files with 202 additions and 38 deletions

View File

@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import psutil
import logging
from enum import Enum
@ -273,6 +291,9 @@ class LoadedModel:
def model_memory(self):
return self.model.model_size()
def model_offloaded_memory(self):
return self.model.model_size() - self.model.loaded_size()
def model_memory_required(self, device):
if device == self.model.current_loaded_device():
return 0
@ -308,15 +329,37 @@ class LoadedModel:
return True
return False
def model_unload(self, unpatch_weights=True):
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
self.model.partially_unload(self.model.offload_device, memory_to_free)
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.real_model = None
return True
def model_use_more_vram(self, extra_memory):
return self.model.partially_load(self.device, extra_memory)
def __eq__(self, other):
return self.model is other.model
def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
if m.device == device:
extra_memory -= m.model_use_more_vram(extra_memory)
if extra_memory <= 0:
break
def offloaded_memory(loaded_models, device):
offloaded_mem = 0
for m in loaded_models:
if m.device == device:
offloaded_mem += m.model_offloaded_memory()
return offloaded_mem
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
@ -363,11 +406,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
for x in sorted(can_unload):
i = x[-1]
memory_to_free = None
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
current_loaded_models[i].model_unload()
unloaded_model.append(i)
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free, free_mem):
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@ -422,12 +469,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load)))
else:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
if len(models_to_load) == 0:
return
@ -467,6 +516,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_mem = get_free_memory(d)
if free_mem > minimum_memory_required:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
return

View File

@ -1,8 +1,27 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import copy
import inspect
import logging
import uuid
import collections
import comfy.utils
import comfy.model_management
@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
model_options["disable_cfg1_optimization"] = True
return model_options
def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@ -82,16 +116,29 @@ class ModelPatcher:
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.lowvram_patch_counter = 0
self.patches_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
if not hasattr(self.model, 'lowvram_patch_counter'):
self.model.lowvram_patch_counter = 0
if not hasattr(self.model, 'model_lowvram'):
self.model.model_lowvram = False
def model_size(self):
if self.size > 0:
return self.size
self.size = comfy.model_management.module_size(self.model)
return self.size
def loaded_size(self):
return self.model.model_loaded_weight_memory
def lowvram_patch_counter(self):
return self.model.lowvram_patch_counter
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
@ -265,16 +312,16 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_weight_to_device(self, key, device_to=None):
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
if key not in self.patches:
return
weight = comfy.utils.get_attr(self.model, key)
inplace_update = self.weight_inplace_update
inplace_update = self.weight_inplace_update or inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
@ -304,28 +351,24 @@ class ModelPatcher:
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = self.model_size()
return self.model
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
for n, m in self.model.named_modules():
lowvram_weight = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
if m.comfy_cast_weights:
continue
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
@ -347,16 +390,31 @@ class ModelPatcher:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
else:
if hasattr(m, "comfy_cast_weights"):
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
self.patch_weight_to_device(weight_key) #TODO: speed this up without causing OOM
mem_counter += comfy.model_management.module_size(m)
if m.weight is not None and m.weight.device == device_to:
continue
self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
self.patch_weight_to_device(bias_key)
m.to(device_to)
mem_counter += comfy.model_management.module_size(m)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
if lowvram_counter > 0:
self.model.model_lowvram = True
else:
self.model.model_lowvram = False
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
return self.model
def calculate_weight(self, patches, weight, key):
@ -529,31 +587,28 @@ class ModelPatcher:
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self.model_lowvram:
if self.model.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
wipe_lowvram_weight(m)
self.model_lowvram = False
self.lowvram_patch_counter = 0
self.model.model_lowvram = False
self.model.lowvram_patch_counter = 0
keys = list(self.backup.keys())
if self.weight_inplace_update:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr_param(self.model, k, self.backup[k])
for k in keys:
bk = self.backup[k]
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, k, bk.weight)
else:
comfy.utils.set_attr_param(self.model, k, bk.weight)
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
keys = list(self.object_patches_backup.keys())
for k in keys:
@ -561,5 +616,57 @@ class ModelPatcher:
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
for n, m in list(self.model.named_modules())[::-1]:
if memory_to_free < memory_freed:
break
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
return memory_freed
def partially_load(self, device_to, extra_memory=0):
if self.model.model_lowvram == False:
return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
pass #TODO: Full load
current_used = self.model.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
return self.model.model_loaded_weight_memory - current_used
def current_loaded_device(self):
return self.model.device