Mirror of https://github.com/comfyanonymous/ComfyUI.git
Try to free enough vram for control lora inference.
commit 51dde87e97
parent e3d0a9a490
comfy/model_management.py
@@ -394,6 +394,12 @@ def cleanup_models():
         x.model_unload()
         del x

+def dtype_size(dtype):
+    dtype_size = 4
+    if dtype == torch.float16 or dtype == torch.bfloat16:
+        dtype_size = 2
+    return dtype_size
+
 def unet_offload_device():
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
@@ -409,11 +415,7 @@ def unet_inital_load_device(parameters, dtype):
     if DISABLE_SMART_MEMORY:
         return cpu_dev

-    dtype_size = 4
-    if dtype == torch.float16 or dtype == torch.bfloat16:
-        dtype_size = 2
-
-    model_size = dtype_size * parameters
+    model_size = dtype_size(dtype) * parameters

     mem_dev = get_free_memory(torch_dev)
     mem_cpu = get_free_memory(cpu_dev)
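The new dtype_size helper maps a torch dtype to bytes per element, so unet_inital_load_device can estimate the weight footprint as parameters times bytes and compare it against free device memory. A minimal usage sketch (the parameter count is hypothetical, for illustration only):

import torch
from comfy import model_management

parameters = 860_000_000  # hypothetical UNet parameter count, illustration only
model_size = model_management.dtype_size(torch.float16) * parameters  # 2 bytes per fp16/bf16 param
print(model_size)  # 1_720_000_000 bytes (~1.6 GiB); an fp32 model would need twice that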
comfy/sample.py
@@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models

-def get_additional_models(positive, negative):
+def get_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))

+    inference_memory = 0
     control_models = []
     for m in control_nets:
         control_models += m.get_models()
+        inference_memory += m.inference_memory_requirements(dtype)

     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
     gligen = [x[1] for x in gligen]
     models = control_models + gligen
-    return models
+    return models, inference_memory

 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
@@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
         noise_mask = prepare_mask(noise_mask, noise.shape, device)

     real_model = None
-    models = get_additional_models(positive, negative)
-    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
+    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
+    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
     real_model = model.model

     noise = noise.to(device)
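Taken together with the previous hunk, the sampler now reserves VRAM for the batch-area estimate plus whatever the attached control models report they will need at inference time. A rough sketch of the resulting call sequence, assuming model is a ModelPatcher and positive/negative are conditioning lists (the variable names are placeholders, not from the commit):

import comfy.sample
import comfy.model_management

models, inference_memory = comfy.sample.get_additional_models(positive, negative, model.model_dtype())
area = noise.shape[0] * noise.shape[2] * noise.shape[3]  # batch * latent height * latent width
memory_required = comfy.model_management.batch_area_memory(area) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required)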
comfy/sd.py (19 changed lines)
@@ -779,6 +779,11 @@ class ControlBase:
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range

+    def inference_memory_requirements(self, dtype):
+        if self.previous_controlnet is not None:
+            return self.previous_controlnet.inference_memory_requirements(dtype)
+        return 0
+
     def control_merge(self, control_input, control_output, control_prev, output_dtype):
         out = {'input':[], 'middle':[], 'output': []}

@@ -985,6 +990,9 @@ class ControlLora(ControlNet):
         out = ControlBase.get_models(self)
         return out

+    def inference_memory_requirements(self, dtype):
+        return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
     if "lora_controlnet" in controlnet_data:
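The ControlLora estimate is simply the parameter count of its stored control weights times bytes per element, plus whatever a chained previous controlnet reports through the ControlBase method above. A small sketch of that arithmetic with a stand-in state dict (key name and size invented for illustration):

import torch
from comfy import model_management, utils

control_weights = {"down.weight": torch.zeros(1000, 1000)}      # stand-in, 1_000_000 elements
params = utils.calculate_parameters(control_weights)             # default prefix "" counts every key
estimate = params * model_management.dtype_size(torch.float16)   # 2_000_000 bytes at fp16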
@@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl

     return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)

-def calculate_parameters(sd, prefix):
-    params = 0
-    for k in sd.keys():
-        if k.startswith(prefix):
-            params += sd[k].nelement()
-    return params
-
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
     sd = utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
@@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model = None
     clip_target = None

-    parameters = calculate_parameters(sd, "model.diffusion_model.")
+    parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
     fp16 = model_management.should_use_fp16(model_params=parameters)

     class WeightsLoader(torch.nn.Module):
@@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

 def load_unet(unet_path): #load unet in diffusers format
     sd = utils.load_torch_file(unet_path)
-    parameters = calculate_parameters(sd, "")
+    parameters = utils.calculate_parameters(sd)
     fp16 = model_management.should_use_fp16(model_params=parameters)

     model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
comfy/utils.py
@@ -32,6 +32,13 @@ def save_torch_file(sd, ckpt, metadata=None):
     else:
         safetensors.torch.save_file(sd, ckpt)

+def calculate_parameters(sd, prefix=""):
+    params = 0
+    for k in sd.keys():
+        if k.startswith(prefix):
+            params += sd[k].nelement()
+    return params
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
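calculate_parameters now lives in comfy/utils.py with a default prefix of "", which lets load_unet drop the prefix argument while load_checkpoint_guess_config still filters to the diffusion model's keys. A quick usage sketch with a toy state dict (keys invented for illustration):

import torch
from comfy import utils

sd = {
    "model.diffusion_model.w": torch.zeros(4, 4),  # 16 elements, matches the prefix
    "first_stage_model.w": torch.zeros(8),         # ignored when a prefix is given
}
print(utils.calculate_parameters(sd, "model.diffusion_model."))  # 16
print(utils.calculate_parameters(sd))                            # 24: the empty prefix matches every key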