mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Partial model shift support.
This commit is contained in:
parent
8115d8cce9
commit
08f92d55e9
@ -1,3 +1,21 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Comfy
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -273,6 +291,9 @@ class LoadedModel:
|
|||||||
def model_memory(self):
|
def model_memory(self):
|
||||||
return self.model.model_size()
|
return self.model.model_size()
|
||||||
|
|
||||||
|
def model_offloaded_memory(self):
|
||||||
|
return self.model.model_size() - self.model.loaded_size()
|
||||||
|
|
||||||
def model_memory_required(self, device):
|
def model_memory_required(self, device):
|
||||||
if device == self.model.current_loaded_device():
|
if device == self.model.current_loaded_device():
|
||||||
return 0
|
return 0
|
||||||
@ -308,15 +329,37 @@ class LoadedModel:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def model_unload(self, unpatch_weights=True):
|
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||||
|
if memory_to_free is not None:
|
||||||
|
if memory_to_free < self.model.loaded_size():
|
||||||
|
self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||||
|
return False
|
||||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||||
self.model.model_patches_to(self.model.offload_device)
|
self.model.model_patches_to(self.model.offload_device)
|
||||||
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
||||||
self.real_model = None
|
self.real_model = None
|
||||||
|
return True
|
||||||
|
|
||||||
|
def model_use_more_vram(self, extra_memory):
|
||||||
|
return self.model.partially_load(self.device, extra_memory)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.model is other.model
|
return self.model is other.model
|
||||||
|
|
||||||
|
def use_more_memory(extra_memory, loaded_models, device):
|
||||||
|
for m in loaded_models:
|
||||||
|
if m.device == device:
|
||||||
|
extra_memory -= m.model_use_more_vram(extra_memory)
|
||||||
|
if extra_memory <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
def offloaded_memory(loaded_models, device):
|
||||||
|
offloaded_mem = 0
|
||||||
|
for m in loaded_models:
|
||||||
|
if m.device == device:
|
||||||
|
offloaded_mem += m.model_offloaded_memory()
|
||||||
|
return offloaded_mem
|
||||||
|
|
||||||
def minimum_inference_memory():
|
def minimum_inference_memory():
|
||||||
return (1024 * 1024 * 1024) * 1.2
|
return (1024 * 1024 * 1024) * 1.2
|
||||||
|
|
||||||
@ -363,11 +406,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
|||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
|
memory_to_free = None
|
||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY:
|
||||||
if get_free_memory(device) > memory_required:
|
free_mem = get_free_memory(device)
|
||||||
|
if free_mem > memory_required:
|
||||||
break
|
break
|
||||||
current_loaded_models[i].model_unload()
|
memory_to_free = memory_required - free_mem
|
||||||
unloaded_model.append(i)
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
|
if current_loaded_models[i].model_unload(memory_to_free, free_mem):
|
||||||
|
unloaded_model.append(i)
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded_models.append(current_loaded_models.pop(i))
|
||||||
@ -422,12 +469,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||||
for d in devs:
|
for d in devs:
|
||||||
if d != torch.device("cpu"):
|
if d != torch.device("cpu"):
|
||||||
free_memory(extra_mem, d, models_already_loaded)
|
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
|
||||||
free_mem = get_free_memory(d)
|
free_mem = get_free_memory(d)
|
||||||
if free_mem < minimum_memory_required:
|
if free_mem < minimum_memory_required:
|
||||||
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
|
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
|
||||||
models_to_load = free_memory(minimum_memory_required, d)
|
models_to_load = free_memory(minimum_memory_required, d)
|
||||||
logging.info("{} models unloaded.".format(len(models_to_load)))
|
logging.info("{} models unloaded.".format(len(models_to_load)))
|
||||||
|
else:
|
||||||
|
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||||
if len(models_to_load) == 0:
|
if len(models_to_load) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -467,6 +516,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
|
|
||||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
current_loaded_models.insert(0, loaded_model)
|
current_loaded_models.insert(0, loaded_model)
|
||||||
|
|
||||||
|
|
||||||
|
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||||
|
for d in devs:
|
||||||
|
if d != torch.device("cpu"):
|
||||||
|
free_mem = get_free_memory(d)
|
||||||
|
if free_mem > minimum_memory_required:
|
||||||
|
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,27 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Comfy
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
import collections
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -63,6 +82,21 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|||||||
model_options["disable_cfg1_optimization"] = True
|
model_options["disable_cfg1_optimization"] = True
|
||||||
return model_options
|
return model_options
|
||||||
|
|
||||||
|
def wipe_lowvram_weight(m):
|
||||||
|
if hasattr(m, "prev_comfy_cast_weights"):
|
||||||
|
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||||
|
del m.prev_comfy_cast_weights
|
||||||
|
m.weight_function = None
|
||||||
|
m.bias_function = None
|
||||||
|
|
||||||
|
class LowVramPatch:
|
||||||
|
def __init__(self, key, model_patcher):
|
||||||
|
self.key = key
|
||||||
|
self.model_patcher = model_patcher
|
||||||
|
def __call__(self, weight):
|
||||||
|
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -82,16 +116,29 @@ class ModelPatcher:
|
|||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
self.weight_inplace_update = weight_inplace_update
|
self.weight_inplace_update = weight_inplace_update
|
||||||
self.model_lowvram = False
|
|
||||||
self.lowvram_patch_counter = 0
|
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'lowvram_patch_counter'):
|
||||||
|
self.model.lowvram_patch_counter = 0
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'model_lowvram'):
|
||||||
|
self.model.model_lowvram = False
|
||||||
|
|
||||||
def model_size(self):
|
def model_size(self):
|
||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
return self.size
|
return self.size
|
||||||
self.size = comfy.model_management.module_size(self.model)
|
self.size = comfy.model_management.module_size(self.model)
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
def loaded_size(self):
|
||||||
|
return self.model.model_loaded_weight_memory
|
||||||
|
|
||||||
|
def lowvram_patch_counter(self):
|
||||||
|
return self.model.lowvram_patch_counter
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
@ -265,16 +312,16 @@ class ModelPatcher:
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_weight_to_device(self, key, device_to=None):
|
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||||
if key not in self.patches:
|
if key not in self.patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
weight = comfy.utils.get_attr(self.model, key)
|
weight = comfy.utils.get_attr(self.model, key)
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update
|
inplace_update = self.weight_inplace_update or inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
@ -304,28 +351,24 @@ class ModelPatcher:
|
|||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
|
self.model.model_loaded_weight_memory = self.model_size()
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||||
self.patch_model(device_to, patch_weights=False)
|
|
||||||
|
|
||||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
||||||
class LowVramPatch:
|
|
||||||
def __init__(self, key, model_patcher):
|
|
||||||
self.key = key
|
|
||||||
self.model_patcher = model_patcher
|
|
||||||
def __call__(self, weight):
|
|
||||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
|
||||||
|
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
|
lowvram_counter = 0
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
|
lowvram_counter += 1
|
||||||
|
if m.comfy_cast_weights:
|
||||||
|
continue
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
@ -347,16 +390,31 @@ class ModelPatcher:
|
|||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
else:
|
else:
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
if m.comfy_cast_weights:
|
||||||
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
if hasattr(m, "weight"):
|
if hasattr(m, "weight"):
|
||||||
self.patch_weight_to_device(weight_key) #TODO: speed this up without causing OOM
|
mem_counter += comfy.model_management.module_size(m)
|
||||||
|
if m.weight is not None and m.weight.device == device_to:
|
||||||
|
continue
|
||||||
|
self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
|
||||||
self.patch_weight_to_device(bias_key)
|
self.patch_weight_to_device(bias_key)
|
||||||
m.to(device_to)
|
m.to(device_to)
|
||||||
mem_counter += comfy.model_management.module_size(m)
|
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
|
||||||
self.model_lowvram = True
|
if lowvram_counter > 0:
|
||||||
self.lowvram_patch_counter = patch_counter
|
self.model.model_lowvram = True
|
||||||
|
else:
|
||||||
|
self.model.model_lowvram = False
|
||||||
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
|
self.model.model_loaded_weight_memory = mem_counter
|
||||||
|
|
||||||
|
|
||||||
|
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||||
|
self.patch_model(device_to, patch_weights=False)
|
||||||
|
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
def calculate_weight(self, patches, weight, key):
|
||||||
@ -529,31 +587,28 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
if self.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if hasattr(m, "prev_comfy_cast_weights"):
|
wipe_lowvram_weight(m)
|
||||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
|
||||||
del m.prev_comfy_cast_weights
|
|
||||||
m.weight_function = None
|
|
||||||
m.bias_function = None
|
|
||||||
|
|
||||||
self.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
self.lowvram_patch_counter = 0
|
self.model.lowvram_patch_counter = 0
|
||||||
|
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
|
|
||||||
if self.weight_inplace_update:
|
for k in keys:
|
||||||
for k in keys:
|
bk = self.backup[k]
|
||||||
comfy.utils.copy_to_param(self.model, k, self.backup[k])
|
if bk.inplace_update:
|
||||||
else:
|
comfy.utils.copy_to_param(self.model, k, bk.weight)
|
||||||
for k in keys:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, k, self.backup[k])
|
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||||
|
|
||||||
self.backup.clear()
|
self.backup.clear()
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
keys = list(self.object_patches_backup.keys())
|
keys = list(self.object_patches_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
@ -561,5 +616,57 @@ class ModelPatcher:
|
|||||||
|
|
||||||
self.object_patches_backup.clear()
|
self.object_patches_backup.clear()
|
||||||
|
|
||||||
|
def partially_unload(self, device_to, memory_to_free=0):
|
||||||
|
memory_freed = 0
|
||||||
|
patch_counter = 0
|
||||||
|
|
||||||
|
for n, m in list(self.model.named_modules())[::-1]:
|
||||||
|
if memory_to_free < memory_freed:
|
||||||
|
break
|
||||||
|
|
||||||
|
shift_lowvram = False
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
module_mem = comfy.model_management.module_size(m)
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
|
|
||||||
|
if m.weight is not None and m.weight.device != device_to:
|
||||||
|
for key in [weight_key, bias_key]:
|
||||||
|
bk = self.backup.get(key, None)
|
||||||
|
if bk is not None:
|
||||||
|
if bk.inplace_update:
|
||||||
|
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||||
|
else:
|
||||||
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
|
self.backup.pop(key)
|
||||||
|
|
||||||
|
m.to(device_to)
|
||||||
|
if weight_key in self.patches:
|
||||||
|
m.weight_function = LowVramPatch(weight_key, self)
|
||||||
|
patch_counter += 1
|
||||||
|
if bias_key in self.patches:
|
||||||
|
m.bias_function = LowVramPatch(bias_key, self)
|
||||||
|
patch_counter += 1
|
||||||
|
|
||||||
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
|
m.comfy_cast_weights = True
|
||||||
|
memory_freed += module_mem
|
||||||
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
|
self.model.model_lowvram = True
|
||||||
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
|
self.model.model_loaded_weight_memory -= memory_freed
|
||||||
|
return memory_freed
|
||||||
|
|
||||||
|
def partially_load(self, device_to, extra_memory=0):
|
||||||
|
if self.model.model_lowvram == False:
|
||||||
|
return 0
|
||||||
|
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
|
pass #TODO: Full load
|
||||||
|
current_used = self.model.model_loaded_weight_memory
|
||||||
|
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
|
||||||
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
def current_loaded_device(self):
|
def current_loaded_device(self):
|
||||||
return self.model.device
|
return self.model.device
|
||||||
|
Loading…
Reference in New Issue
Block a user