Try to free enough vram for control lora inference.

This commit is contained in:
comfyanonymous 2023-08-24 17:20:54 -04:00
parent e3d0a9a490
commit 51dde87e97
4 changed files with 30 additions and 18 deletions

View File

@ -394,6 +394,12 @@ def cleanup_models():
x.model_unload() x.model_unload()
del x del x
def dtype_size(dtype):
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
return dtype_size
def unet_offload_device(): def unet_offload_device():
if vram_state == VRAMState.HIGH_VRAM: if vram_state == VRAMState.HIGH_VRAM:
return get_torch_device() return get_torch_device()
@ -409,11 +415,7 @@ def unet_inital_load_device(parameters, dtype):
if DISABLE_SMART_MEMORY: if DISABLE_SMART_MEMORY:
return cpu_dev return cpu_dev
dtype_size = 4 model_size = dtype_size(dtype) * parameters
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
model_size = dtype_size * parameters
mem_dev = get_free_memory(torch_dev) mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev) mem_cpu = get_free_memory(cpu_dev)

View File

@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type):
models += [c[1][model_type]] models += [c[1][model_type]]
return models return models
def get_additional_models(positive, negative): def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning""" """loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
inference_memory = 0
control_models = [] control_models = []
for m in control_nets: for m in control_nets:
control_models += m.get_models() control_models += m.get_models()
inference_memory += m.inference_memory_requirements(dtype)
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen] gligen = [x[1] for x in gligen]
models = control_models + gligen models = control_models + gligen
return models return models, inference_memory
def cleanup_additional_models(models): def cleanup_additional_models(models):
"""cleanup additional models that were loaded""" """cleanup additional models that were loaded"""
@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
noise_mask = prepare_mask(noise_mask, noise.shape, device) noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None real_model = None
models = get_additional_models(positive, negative) models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3])) comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
real_model = model.model real_model = model.model
noise = noise.to(device) noise = noise.to(device)

View File

@ -779,6 +779,11 @@ class ControlBase:
c.strength = self.strength c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range c.timestep_percent_range = self.timestep_percent_range
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
return self.previous_controlnet.inference_memory_requirements(dtype)
return 0
def control_merge(self, control_input, control_output, control_prev, output_dtype): def control_merge(self, control_input, control_output, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []} out = {'input':[], 'middle':[], 'output': []}
@ -985,6 +990,9 @@ class ControlLora(ControlNet):
out = ControlBase.get_models(self) out = ControlBase.get_models(self)
return out return out
def inference_memory_requirements(self, dtype):
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def load_controlnet(ckpt_path, model=None): def load_controlnet(ckpt_path, model=None):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
if "lora_controlnet" in controlnet_data: if "lora_controlnet" in controlnet_data:
@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
def calculate_parameters(sd, prefix):
params = 0
for k in sd.keys():
if k.startswith(prefix):
params += sd[k].nelement()
return params
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
sd = utils.load_torch_file(ckpt_path) sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys() sd_keys = sd.keys()
@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model = None model = None
clip_target = None clip_target = None
parameters = calculate_parameters(sd, "model.diffusion_model.") parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
fp16 = model_management.should_use_fp16(model_params=parameters) fp16 = model_management.should_use_fp16(model_params=parameters)
class WeightsLoader(torch.nn.Module): class WeightsLoader(torch.nn.Module):
@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet(unet_path): #load unet in diffusers format def load_unet(unet_path): #load unet in diffusers format
sd = utils.load_torch_file(unet_path) sd = utils.load_torch_file(unet_path)
parameters = calculate_parameters(sd, "") parameters = utils.calculate_parameters(sd)
fp16 = model_management.should_use_fp16(model_params=parameters) fp16 = model_management.should_use_fp16(model_params=parameters)
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)

View File

@ -32,6 +32,13 @@ def save_torch_file(sd, ckpt, metadata=None):
else: else:
safetensors.torch.save_file(sd, ckpt) safetensors.torch.save_file(sd, ckpt)
def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
params += sd[k].nelement()
return params
def transformers_convert(sd, prefix_from, prefix_to, number): def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = { keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight", "{}positional_embedding": "{}embeddings.position_embedding.weight",