diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index ddf538de..6bd06f4a 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -3,7 +3,7 @@ import comfy.sd import comfy.model_management import nodes import torch - +import re class TripleCLIPLoader: @classmethod def INPUT_TYPES(s): @@ -95,11 +95,70 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): CATEGORY = "conditioning/controlnet" DEPRECATED = True +class SkipLayerGuidanceSD3: + ''' + Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. + Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) + Experimental implementation by Dango233@StabilityAI. + ''' + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "layers": ("STRING", {"default": "7,8,9", "multiline": False}), + "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), + "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "skip_guidance" + + CATEGORY = "advanced/guidance" + + + def skip_guidance(self, model, layers, scale, start_percent, end_percent): + if layers == "" or layers == None: + return (model, ) + # check if layer is comma separated integers + assert layers.replace(",", "").isdigit(), "Layers must be comma separated integers" + def skip(args, extra_args): + return args + + model_sampling = model.get_model_object("model_sampling") + + def post_cfg_function(args): + model = args["model"] + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + x = args["input"] + model_options = args["model_options"].copy() + + for layer in layers: + model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer) + model_sampling.percent_to_sigma(start_percent) + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + sigma_ = sigma[0].item() + if scale > 0 and sigma_ > sigma_end and sigma_ < sigma_start: + (slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) + cfg_result = cfg_result + (cond_pred - slg) * scale + return cfg_result + + layers = re.findall(r'\d+', layers) + layers = [int(i) for i in layers] + m = model.clone() + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return (m, ) + + NODE_CLASS_MAPPINGS = { "TripleCLIPLoader": TripleCLIPLoader, "EmptySD3LatentImage": EmptySD3LatentImage, "CLIPTextEncodeSD3": CLIPTextEncodeSD3, "ControlNetApplySD3": ControlNetApplySD3, + "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, } NODE_DISPLAY_NAME_MAPPINGS = {