diff --git a/comfy/sample.py b/comfy/sample.py index 5c8a7d13..3c65d0a8 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -52,9 +52,16 @@ def convert_cond(cond): out.append(temp) return out -def get_additional_models(positive, negative, dtype): - """loads additional models in positive and negative conditioning""" - control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) +def get_additional_models(conds, dtype): + """loads additional models in conditioning""" + cnets = [] + gligen = [] + + for i in range(len(conds)): + cnets += get_models_from_cond(conds[i], "control") + gligen += get_models_from_cond(conds[i], "gligen") + + control_nets = set(cnets) inference_memory = 0 control_models = [] @@ -62,7 +69,6 @@ def get_additional_models(positive, negative, dtype): control_models += m.get_models() inference_memory += m.inference_memory_requirements(dtype) - gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") gligen = [x[1] for x in gligen] models = control_models + gligen return models, inference_memory @@ -73,24 +79,25 @@ def cleanup_additional_models(models): if hasattr(m, 'cleanup'): m.cleanup() -def prepare_sampling(model, noise_shape, positive, negative, noise_mask): +def prepare_sampling(model, noise_shape, conds, noise_mask): device = model.load_device - positive = convert_cond(positive) - negative = convert_cond(negative) + for i in range(len(conds)): + conds[i] = convert_cond(conds[i]) if noise_mask is not None: noise_mask = prepare_mask(noise_mask, noise_shape, device) real_model = None - models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) + models, inference_memory = get_additional_models(conds, model.model_dtype()) comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) real_model = model.model - return real_model, positive, negative, noise_mask, models + return real_model, conds, noise_mask, models def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) + real_model, conds_copy, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask) + positive_copy, negative_copy = conds_copy noise = noise.to(model.load_device) latent_image = latent_image.to(model.load_device) @@ -105,14 +112,19 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) + real_model, conds, noise_mask, models = prepare_sampling(model, noise.shape, [positive, negative], noise_mask) noise = noise.to(model.load_device) latent_image = latent_image.to(model.load_device) sigmas = sigmas.to(model.load_device) - samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = comfy.samplers.sample(real_model, noise, conds[0], conds[1], cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(comfy.model_management.intermediate_device()) cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) + + control_cleanup = [] + for i in range(len(conds)): + control_cleanup += get_models_from_cond(conds[i], "control") + + cleanup_additional_models(set(control_cleanup)) return samples diff --git a/comfy/samplers.py b/comfy/samplers.py index 204a98f9..f18de200 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -260,11 +260,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option return cfg_result class CFGNoisePredictor(torch.nn.Module): - def __init__(self, model): + def __init__(self, model, cond_scale=1.0): super().__init__() self.inner_model = model - def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): - out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) + self.cond_scale = cond_scale + def apply_model(self, x, timestep, conds, model_options={}, seed=None): + out = sampling_function(self.inner_model, x, timestep, conds.get("negative", None), conds.get("positive", None), self.cond_scale, model_options=model_options, seed=seed) return out def forward(self, *args, **kwargs): return self.apply_model(*args, **kwargs) @@ -274,13 +275,13 @@ class KSamplerX0Inpaint(torch.nn.Module): super().__init__() self.inner_model = model self.sigmas = sigmas - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): + def forward(self, x, sigma, conds, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) + out = self.inner_model(x, sigma, conds=conds, model_options=model_options, seed=seed) if denoise_mask is not None: out = out * denoise_mask + self.latent_image * latent_mask return out @@ -568,45 +569,56 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return KSAMPLER(sampler_function, extra_options, inpaint_options) -def wrap_model(model): - model_denoise = CFGNoisePredictor(model) - return model_denoise -def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - positive = positive[:] - negative = negative[:] +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): + for k in conds: + conds[k] = conds[k][:] + resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) + for k in conds: + calculate_start_end_timesteps(model, conds[k]) - model_wrap = wrap_model(model) + if hasattr(model, 'extra_conds'): + for k in conds: + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) - calculate_start_end_timesteps(model, negative) - calculate_start_end_timesteps(model, positive) + #make sure each cond area has an opposite one with the same area + for k in conds: + for c in conds[k]: + for kk in conds: + if k != kk: + create_cond_with_same_area_if_none(conds[kk], c) + for k in conds: + pre_run_control(model, conds[k]) + + if "positive" in conds: + positive = conds["positive"] + for k in conds: + if k != "positive": + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x]) + + return conds + + +def sample_advanced(model, noise, conds, guider_class, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. latent_image = model.process_latent_in(latent_image) - if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + conds = process_conds(model, noise, conds, device, latent_image, denoise_mask, seed) + model_wrap = guider_class(model) - #make sure each cond area has an opposite one with the same area - for c in positive: - create_cond_with_same_area_if_none(negative, c) - for c in negative: - create_cond_with_same_area_if_none(positive, c) - - pre_run_control(model, negative + positive) - - apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) - apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) - - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} + extra_args = {"conds": conds, "model_options": model_options, "seed":seed} samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return model.process_latent_out(samples.to(torch.float32)) + +def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + return sample_advanced(model, noise, {"positive": positive, "negative": negative}, lambda a: CFGNoisePredictor(a, cfg), device, sampler, sigmas, model_options, latent_image, denoise_mask, callback, disable_pbar, seed) + + SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]