diff --git a/comfy/model_management.py b/comfy/model_management.py
index 544a945b..f1873a34 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -394,6 +394,12 @@ def cleanup_models():
         x.model_unload()
         del x
 
+def dtype_size(dtype):
+    dtype_size = 4
+    if dtype == torch.float16 or dtype == torch.bfloat16:
+        dtype_size = 2
+    return dtype_size
+
 def unet_offload_device():
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
@@ -409,11 +415,7 @@ def unet_inital_load_device(parameters, dtype):
     if DISABLE_SMART_MEMORY:
         return cpu_dev
 
-    dtype_size = 4
-    if dtype == torch.float16 or dtype == torch.bfloat16:
-        dtype_size = 2
-
-    model_size = dtype_size * parameters
+    model_size = dtype_size(dtype) * parameters
 
     mem_dev = get_free_memory(torch_dev)
     mem_cpu = get_free_memory(cpu_dev)
diff --git a/comfy/sample.py b/comfy/sample.py
index d7292024..79ea37e0 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models
 
-def get_additional_models(positive, negative):
+def get_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
 
+    inference_memory = 0
     control_models = []
     for m in control_nets:
         control_models += m.get_models()
+        inference_memory += m.inference_memory_requirements(dtype)
 
     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
     gligen = [x[1] for x in gligen]
     models = control_models + gligen
-    return models
+    return models, inference_memory
 
 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
@@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
         noise_mask = prepare_mask(noise_mask, noise.shape, device)
 
     real_model = None
-    models = get_additional_models(positive, negative)
-    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
+    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
+    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
     real_model = model.model
 
     noise = noise.to(device)
diff --git a/comfy/sd.py b/comfy/sd.py
index 89df5a77..3568a2aa 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -779,6 +779,11 @@ class ControlBase:
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range
 
+    def inference_memory_requirements(self, dtype):
+        if self.previous_controlnet is not None:
+            return self.previous_controlnet.inference_memory_requirements(dtype)
+        return 0
+
     def control_merge(self, control_input, control_output, control_prev, output_dtype):
         out = {'input':[], 'middle':[], 'output': []}
 
@@ -985,6 +990,9 @@ class ControlLora(ControlNet):
         out = ControlBase.get_models(self)
         return out
 
+    def inference_memory_requirements(self, dtype):
+        return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
     if "lora_controlnet" in controlnet_data:
@@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
 
     return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
-def calculate_parameters(sd, prefix):
-    params = 0
-    for k in sd.keys():
-        if k.startswith(prefix):
-            params += sd[k].nelement()
-    return params
-
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
     sd = utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
@@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model = None
     clip_target = None
 
-    parameters = calculate_parameters(sd, "model.diffusion_model.")
+    parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
     class WeightsLoader(torch.nn.Module):
@@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 
 def load_unet(unet_path): #load unet in diffusers format
     sd = utils.load_torch_file(unet_path)
-    parameters = calculate_parameters(sd, "")
+    parameters = utils.calculate_parameters(sd)
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
     model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
diff --git a/comfy/utils.py b/comfy/utils.py
index 3bbe4f9a..e69125ab 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -32,6 +32,13 @@ def save_torch_file(sd, ckpt, metadata=None):
     else:
         safetensors.torch.save_file(sd, ckpt)
 
+def calculate_parameters(sd, prefix=""):
+    params = 0
+    for k in sd.keys():
+        if k.startswith(prefix):
+            params += sd[k].nelement()
+    return params
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
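
A minimal standalone sketch (separate from the patch itself) of how the two helpers introduced in this diff compose: calculate_parameters counts tensor elements under a key prefix, dtype_size maps a torch dtype to bytes per element, and their product estimates a weight footprint in bytes, which is what ControlLora.inference_memory_requirements computes for its control_weights. The helper bodies below are condensed from the diff; the state dict is illustrative only, not from the repository.

    import torch

    def dtype_size(dtype):
        # 2 bytes per element for half-precision dtypes, otherwise 4
        # (same rule as the new helper in comfy/model_management.py)
        return 2 if dtype in (torch.float16, torch.bfloat16) else 4

    def calculate_parameters(sd, prefix=""):
        # total element count over keys under the given prefix
        # (same logic as the relocated helper in comfy/utils.py)
        return sum(v.nelement() for k, v in sd.items() if k.startswith(prefix))

    # Illustrative state dict: 10*100 + 10 = 1,010 parameters.
    sd = {"model.diffusion_model.weight": torch.zeros(10, 100),
          "model.diffusion_model.bias": torch.zeros(10)}

    params = calculate_parameters(sd, "model.diffusion_model.")
    print(params)                              # 1010
    print(params * dtype_size(torch.float16)) # 2020 bytes in fp16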