From 574363a8a69cb48db71c96d03fe056d56853f4f6 Mon Sep 17 00:00:00 2001 From: Hari Date: Sat, 16 Dec 2023 00:28:16 +0530 Subject: [PATCH] Implement Perp-Neg --- comfy/samplers.py | 3 +- comfy_extras/nodes_perpneg.py | 58 +++++++++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_perpneg.py diff --git a/comfy/samplers.py b/comfy/samplers.py index 39bc3774..35c9ccf0 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -251,7 +251,8 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} cfg_result = x - model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py new file mode 100644 index 00000000..36f2eb01 --- /dev/null +++ b/comfy_extras/nodes_perpneg.py @@ -0,0 +1,58 @@ +import torch +import comfy.model_management +import comfy.sample +import comfy.samplers +import comfy.utils + + +class PerpNeg: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "clip": ("CLIP", ), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, clip, neg_scale): + m = model.clone() + + tokens = clip.tokenize("") + nocond, nocond_pooled = clip.encode_from_tokens(tokens, return_pooled=True) + nocond = [[nocond, {"pooled_output": nocond_pooled}]] + nocond = comfy.sample.convert_cond(nocond) + + def cfg_function(args): + model = args["model"] + noise_pred_pos = args["cond_denoised"] + noise_pred_neg = args["uncond_denoised"] + cond_scale = args["cond_scale"] + x = args["input"] + sigma = args["sigma"] + model_options = args["model_options"] + + (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond, None, x, sigma, model_options) + + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + cfg_result = x - cfg_result + return cfg_result + + m.set_model_sampler_cfg_function(cfg_function) + + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "PerpNeg": PerpNeg, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PerpNeg": "Perp-Neg", +} diff --git a/nodes.py b/nodes.py index 3d24750c..3031b10a 100644 --- a/nodes.py +++ b/nodes.py @@ -1868,6 +1868,7 @@ def init_custom_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_sag.py", + "nodes_perpneg.py", ] for node_file in extras_files: