mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-10 18:05:16 +00:00
Smarter memory management.
Try to keep models on the vram when possible. Better lowvram mode for controlnets.
This commit is contained in:
parent
2c97c30256
commit
89a0767abf
@ -244,30 +244,15 @@ class Gligen(nn.Module):
|
||||
self.position_net = position_net
|
||||
self.key_dim = key_dim
|
||||
self.max_objs = 30
|
||||
self.lowvram = False
|
||||
self.current_device = torch.device("cpu")
|
||||
|
||||
def _set_position(self, boxes, masks, positive_embeddings):
|
||||
if self.lowvram == True:
|
||||
self.position_net.to(boxes.device)
|
||||
|
||||
objs = self.position_net(boxes, masks, positive_embeddings)
|
||||
|
||||
if self.lowvram == True:
|
||||
self.position_net.cpu()
|
||||
def func_lowvram(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
module.to(x.device)
|
||||
r = module(x, objs)
|
||||
module.cpu()
|
||||
return r
|
||||
return func_lowvram
|
||||
else:
|
||||
def func(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
def func(x, extra_options):
|
||||
key = extra_options["transformer_index"]
|
||||
module = self.module_list[key]
|
||||
return module(x, objs)
|
||||
return func
|
||||
|
||||
def set_position(self, latent_image_shape, position_params, device):
|
||||
batch, c, h, w = latent_image_shape
|
||||
@ -312,14 +297,6 @@ class Gligen(nn.Module):
|
||||
masks.to(device),
|
||||
conds.to(device))
|
||||
|
||||
def set_lowvram(self, value=True):
|
||||
self.lowvram = value
|
||||
|
||||
def cleanup(self):
|
||||
self.lowvram = False
|
||||
|
||||
def get_models(self):
|
||||
return [self]
|
||||
|
||||
def load_gligen(sd):
|
||||
sd_k = sd.keys()
|
||||
|
@ -2,6 +2,7 @@ import psutil
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args
|
||||
import torch
|
||||
import sys
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -221,132 +222,161 @@ except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
current_gpu_controlnets = []
|
||||
current_loaded_models = []
|
||||
|
||||
model_accelerated = False
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.model_accelerated = False
|
||||
self.device = model.load_device
|
||||
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def unload_model():
|
||||
global current_loaded_model
|
||||
global model_accelerated
|
||||
global current_gpu_controlnets
|
||||
global vram_state
|
||||
def model_memory_required(self, device):
|
||||
if device == self.model.current_device:
|
||||
return 0
|
||||
else:
|
||||
return self.model_memory()
|
||||
|
||||
if current_loaded_model is not None:
|
||||
if model_accelerated:
|
||||
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||
model_accelerated = False
|
||||
def model_load(self, lowvram_model_memory=0):
|
||||
patch_model_to = None
|
||||
if lowvram_model_memory == 0:
|
||||
patch_model_to = self.device
|
||||
|
||||
current_loaded_model.unpatch_model()
|
||||
current_loaded_model.model.to(current_loaded_model.offload_device)
|
||||
current_loaded_model.model_patches_to(current_loaded_model.offload_device)
|
||||
current_loaded_model = None
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
soft_empty_cache()
|
||||
self.model.model_patches_to(self.device)
|
||||
self.model.model_patches_to(self.model.model_dtype())
|
||||
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
if len(current_gpu_controlnets) > 0:
|
||||
for n in current_gpu_controlnets:
|
||||
n.cpu()
|
||||
current_gpu_controlnets = []
|
||||
try:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if lowvram_model_memory > 0:
|
||||
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
|
||||
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
|
||||
self.model_accelerated = True
|
||||
|
||||
return self.real_model
|
||||
|
||||
def model_unload(self):
|
||||
if self.model_accelerated:
|
||||
accelerate.hooks.remove_hook_from_submodules(self.real_model)
|
||||
self.model_accelerated = False
|
||||
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.model is other.model
|
||||
|
||||
def minimum_inference_memory():
|
||||
return (768 * 1024 * 1024)
|
||||
return (1024 * 1024 * 1024)
|
||||
|
||||
def unload_model_clones(model):
|
||||
to_unload = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if model.is_clone(current_loaded_models[i].model):
|
||||
to_unload = [i] + to_unload
|
||||
|
||||
for i in to_unload:
|
||||
print("unload clone", i)
|
||||
current_loaded_models.pop(i).model_unload()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
unloaded_model = False
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
current_free_mem = get_free_memory(device)
|
||||
if current_free_mem > memory_required:
|
||||
break
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
current_loaded_models.pop(i).model_unload()
|
||||
unloaded_model = True
|
||||
|
||||
if unloaded_model:
|
||||
soft_empty_cache()
|
||||
|
||||
|
||||
def load_models_gpu(models, memory_required=0):
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required)
|
||||
|
||||
models_to_load = []
|
||||
models_already_loaded = []
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
|
||||
if loaded_model in current_loaded_models:
|
||||
index = current_loaded_models.index(loaded_model)
|
||||
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
||||
models_already_loaded.append(loaded_model)
|
||||
else:
|
||||
models_to_load.append(loaded_model)
|
||||
|
||||
if len(models_to_load) == 0:
|
||||
devs = set(map(lambda a: a.device, models_already_loaded))
|
||||
for d in devs:
|
||||
if d != torch.device("cpu"):
|
||||
free_memory(extra_mem, d, models_already_loaded)
|
||||
return
|
||||
|
||||
print("loading new")
|
||||
|
||||
total_memory_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
unload_model_clones(loaded_model.model)
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
torch_dev = model.load_device
|
||||
if is_device_cpu(torch_dev):
|
||||
vram_set_state = VRAMState.DISABLED
|
||||
else:
|
||||
vram_set_state = vram_state
|
||||
lowvram_model_memory = 0
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
else:
|
||||
lowvram_model_memory = 0
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 256 * 1024 * 1024
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
return
|
||||
|
||||
|
||||
def load_model_gpu(model):
|
||||
global current_loaded_model
|
||||
global vram_state
|
||||
global model_accelerated
|
||||
return load_models_gpu([model])
|
||||
|
||||
if model is current_loaded_model:
|
||||
return
|
||||
unload_model()
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
print(sys.getrefcount(current_loaded_models[i].model))
|
||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
torch_dev = model.load_device
|
||||
model.model_patches_to(torch_dev)
|
||||
model.model_patches_to(model.model_dtype())
|
||||
current_loaded_model = model
|
||||
|
||||
if is_device_cpu(torch_dev):
|
||||
vram_set_state = VRAMState.DISABLED
|
||||
else:
|
||||
vram_set_state = vram_state
|
||||
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = model.model_size()
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
|
||||
real_model = model.model
|
||||
patch_model_to = None
|
||||
if vram_set_state == VRAMState.DISABLED:
|
||||
pass
|
||||
elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
|
||||
model_accelerated = False
|
||||
patch_model_to = torch_dev
|
||||
|
||||
try:
|
||||
real_model = model.patch_model(device_to=patch_model_to)
|
||||
except Exception as e:
|
||||
model.unpatch_model()
|
||||
unload_model()
|
||||
raise e
|
||||
|
||||
if patch_model_to is not None:
|
||||
real_model.to(torch_dev)
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
||||
model_accelerated = True
|
||||
elif vram_set_state == VRAMState.LOW_VRAM:
|
||||
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||
accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
|
||||
model_accelerated = True
|
||||
|
||||
return current_loaded_model
|
||||
|
||||
def load_controlnet_gpu(control_models):
|
||||
global current_gpu_controlnets
|
||||
global vram_state
|
||||
if vram_state == VRAMState.DISABLED:
|
||||
return
|
||||
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
for m in control_models:
|
||||
if hasattr(m, 'set_lowvram'):
|
||||
m.set_lowvram(True)
|
||||
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
|
||||
return
|
||||
|
||||
models = []
|
||||
for m in control_models:
|
||||
models += m.get_models()
|
||||
|
||||
for m in current_gpu_controlnets:
|
||||
if m not in models:
|
||||
m.cpu()
|
||||
|
||||
device = get_torch_device()
|
||||
current_gpu_controlnets = []
|
||||
for m in models:
|
||||
current_gpu_controlnets.append(m.to(device))
|
||||
|
||||
|
||||
def load_if_low_vram(model):
|
||||
global vram_state
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
return model.to(get_torch_device())
|
||||
return model
|
||||
|
||||
def unload_if_low_vram(model):
|
||||
global vram_state
|
||||
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
|
||||
return model.cpu()
|
||||
return model
|
||||
for i in to_delete:
|
||||
x = current_loaded_models.pop(i)
|
||||
x.model_unload()
|
||||
del x
|
||||
|
||||
def unet_offload_device():
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
@ -354,6 +384,21 @@ def unet_offload_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def unet_inital_load_device(parameters, dtype):
|
||||
torch_dev = get_torch_device()
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return torch_dev
|
||||
|
||||
cpu_dev = torch.device("cpu")
|
||||
model_size = dtype.itemsize * parameters
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
if mem_dev > mem_cpu and model_size < mem_dev:
|
||||
return torch_dev
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
@ -456,6 +501,13 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
else:
|
||||
return mem_free_total
|
||||
|
||||
def batch_area_memory(area):
|
||||
if xformers_enabled() or pytorch_attention_flash_attention():
|
||||
#TODO: these formulas are copied from maximum_batch_area below
|
||||
return (area / 20) * (1024 * 1024)
|
||||
else:
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
def maximum_batch_area():
|
||||
global vram_state
|
||||
if vram_state == VRAMState.NO_VRAM:
|
||||
|
@ -51,19 +51,24 @@ def get_models_from_cond(cond, model_type):
|
||||
models += [c[1][model_type]]
|
||||
return models
|
||||
|
||||
def load_additional_models(positive, negative, dtype):
|
||||
def get_additional_models(positive, negative):
|
||||
"""loads additional models in positive and negative conditioning"""
|
||||
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
|
||||
|
||||
control_models = []
|
||||
for m in control_nets:
|
||||
control_models += m.get_models()
|
||||
|
||||
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
|
||||
gligen = [x[1].to(dtype) for x in gligen]
|
||||
models = control_nets + gligen
|
||||
comfy.model_management.load_controlnet_gpu(models)
|
||||
gligen = [x[1] for x in gligen]
|
||||
models = control_models + gligen
|
||||
return models
|
||||
|
||||
def cleanup_additional_models(models):
|
||||
"""cleanup additional models that were loaded"""
|
||||
for m in models:
|
||||
m.cleanup()
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
@ -72,7 +77,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
noise_mask = prepare_mask(noise_mask, noise.shape, device)
|
||||
|
||||
real_model = None
|
||||
comfy.model_management.load_model_gpu(model)
|
||||
models = get_additional_models(positive, negative)
|
||||
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[2] * noise.shape[3]))
|
||||
real_model = model.model
|
||||
|
||||
noise = noise.to(device)
|
||||
@ -81,7 +87,6 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
positive_copy = broadcast_cond(positive, noise.shape[0], device)
|
||||
negative_copy = broadcast_cond(negative, noise.shape[0], device)
|
||||
|
||||
models = load_additional_models(positive, negative, model.model_dtype())
|
||||
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
|
@ -88,9 +88,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
gligen_type = gligen[0]
|
||||
gligen_model = gligen[1]
|
||||
if gligen_type == "position":
|
||||
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
|
||||
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
|
||||
else:
|
||||
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
|
||||
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
|
59
comfy/sd.py
59
comfy/sd.py
@ -244,7 +244,7 @@ def set_attr(obj, attr, value):
|
||||
del prev
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0):
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
@ -253,6 +253,10 @@ class ModelPatcher:
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
if current_device is None:
|
||||
self.current_device = self.offload_device
|
||||
else:
|
||||
self.current_device = current_device
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
@ -267,7 +271,7 @@ class ModelPatcher:
|
||||
return size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@ -276,6 +280,11 @@ class ModelPatcher:
|
||||
n.model_keys = self.model_keys
|
||||
return n
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
@ -390,6 +399,11 @@ class ModelPatcher:
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
@ -482,7 +496,7 @@ class ModelPatcher:
|
||||
|
||||
return weight
|
||||
|
||||
def unpatch_model(self):
|
||||
def unpatch_model(self, device_to=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
for k in keys:
|
||||
@ -490,6 +504,11 @@ class ModelPatcher:
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = model_lora_keys_unet(model.model)
|
||||
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
@ -630,11 +649,12 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def decode(self, samples_in):
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
try:
|
||||
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.4
|
||||
model_management.free_memory(memory_used, self.device)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64))
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
|
||||
@ -650,19 +670,19 @@ class VAE:
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.4 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
|
||||
model_management.free_memory(memory_used, self.device)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
@ -677,7 +697,6 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
model_management.unload_model()
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
@ -757,6 +776,7 @@ class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
@ -786,11 +806,9 @@ class ControlNet(ControlBase):
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
||||
self.control_model = model_management.load_if_low_vram(self.control_model)
|
||||
context = torch.cat(cond['c_crossattn'], 1)
|
||||
y = cond.get('c_adm', None)
|
||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
|
||||
self.control_model = model_management.unload_if_low_vram(self.control_model)
|
||||
out = {'middle':[], 'output': []}
|
||||
autocast_enabled = torch.is_autocast_enabled()
|
||||
|
||||
@ -825,7 +843,7 @@ class ControlNet(ControlBase):
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model)
|
||||
out.append(self.control_model_wrapped)
|
||||
return out
|
||||
|
||||
|
||||
@ -1004,7 +1022,6 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
|
||||
def load_t2i_adapter(t2i_data):
|
||||
keys = t2i_data.keys()
|
||||
if 'adapter' in keys:
|
||||
@ -1090,7 +1107,7 @@ def load_gligen(ckpt_path):
|
||||
model = gligen.load_gligen(data)
|
||||
if model_management.should_use_fp16():
|
||||
model = model.half()
|
||||
return model
|
||||
return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||
#TODO: this function is a mess and should be removed eventually
|
||||
@ -1202,8 +1219,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if output_clipvision:
|
||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||
|
||||
dtype = torch.float32
|
||||
if fp16:
|
||||
dtype = torch.float16
|
||||
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
|
||||
if output_vae:
|
||||
@ -1224,7 +1246,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if len(left_over) > 0:
|
||||
print("left over keys:", left_over)
|
||||
|
||||
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
|
||||
model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
|
@ -354,6 +354,7 @@ class PromptExecutor:
|
||||
d = self.outputs_ui.pop(x)
|
||||
del d
|
||||
|
||||
comfy.model_management.cleanup_models()
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
executed = set()
|
||||
|
Loading…
Reference in New Issue
Block a user