diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b1d6d439..52e76b5f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -379,6 +379,9 @@ class ModelPatcher: def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False): self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization) + def set_model_sampler_calc_cond_batch_function(self, sampler_calc_cond_batch_function): + self.model_options["sampler_calc_cond_batch_function"] = sampler_calc_cond_batch_function + def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction): self.model_options["model_function_wrapper"] = unet_wrapper_function diff --git a/comfy/samplers.py b/comfy/samplers.py index 078a675f..25ccaf39 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -373,7 +373,11 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option uncond_ = uncond conds = [cond, uncond_] - out = calc_cond_batch(model, conds, x, timestep, model_options) + if "sampler_calc_cond_batch_function" in model_options: + args = {"conds": conds, "input": x, "sigma": timestep, "model": model, "model_options": model_options} + out = model_options["sampler_calc_cond_batch_function"](args) + else: + out = calc_cond_batch(model, conds, x, timestep, model_options) for fn in model_options.get("sampler_pre_cfg_function", []): args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py index 2fa09e25..7adff202 100644 --- a/comfy_extras/nodes_slg.py +++ b/comfy_extras/nodes_slg.py @@ -78,7 +78,75 @@ class SkipLayerGuidanceDiT: return (m, ) +class SkipLayerGuidanceDiTSimple: + ''' + Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass. + ''' + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), + "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "skip_guidance" + EXPERIMENTAL = True + + DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass." + + CATEGORY = "advanced/guidance" + + def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""): + def skip(args, extra_args): + return args + + model_sampling = model.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + + double_layers = re.findall(r'\d+', double_layers) + double_layers = [int(i) for i in double_layers] + + single_layers = re.findall(r'\d+', single_layers) + single_layers = [int(i) for i in single_layers] + + if len(double_layers) == 0 and len(single_layers) == 0: + return (model, ) + + def calc_cond_batch_function(args): + x = args["input"] + model = args["model"] + conds = args["conds"] + sigma = args["sigma"] + + model_options = args["model_options"] + slg_model_options = model_options.copy() + + for layer in double_layers: + slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "double_block", layer) + + for layer in single_layers: + slg_model_options = comfy.model_patcher.set_model_options_patch_replace(slg_model_options, skip, "dit", "single_block", layer) + + cond, uncond = conds + sigma_ = sigma[0].item() + if sigma_ >= sigma_end and sigma_ <= sigma_start and uncond is not None: + cond_out, _ = comfy.samplers.calc_cond_batch(model, [cond, None], x, sigma, model_options) + _, uncond_out = comfy.samplers.calc_cond_batch(model, [None, uncond], x, sigma, slg_model_options) + out = [cond_out, uncond_out] + else: + out = comfy.samplers.calc_cond_batch(model, conds, x, sigma, model_options) + + return out + + m = model.clone() + m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function) + + return (m, ) NODE_CLASS_MAPPINGS = { "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT, + "SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple, }