From 5bea1d2ec95ca4519ed96b09b45fcd327e6dc751 Mon Sep 17 00:00:00 2001 From: yoinked Date: Wed, 11 Dec 2024 13:51:51 -0800 Subject: [PATCH] Add MaHiRo (improved/alternate CFG) (#5975) * Add MaHiRo (improved CFG) long explanation of what it is is [here](https://huggingface.co/spaces/yoinked/blue-arxiv) (2024-1208.1) note: if the node name has encoding issues (utf 8/whatever), id suggest to replace the face at the end with `(>w<)` * add it to nodes.py, add description, and make it a post_cfg function * fix * revert the sampler_cfg_function thing * switch cfg to args["denoised"] --- comfy_extras/nodes_mahiro.py | 41 ++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 42 insertions(+) create mode 100644 comfy_extras/nodes_mahiro.py diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py new file mode 100644 index 00000000..8fcdfba7 --- /dev/null +++ b/comfy_extras/nodes_mahiro.py @@ -0,0 +1,41 @@ +import torch +import torch.nn.functional as F + +class Mahiro: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt." + def patch(self, model): + m = model.clone() + def mahiro_normd(args): + scale: float = args['cond_scale'] + cond_p: torch.Tensor = args['cond_denoised'] + uncond_p: torch.Tensor = args['uncond_denoised'] + #naive leap + leap = cond_p * scale + #sim with uncond leap + u_leap = uncond_p * scale + cfg = args["denoised"] + merge = (leap + cfg) / 2 + normu = torch.sqrt(u_leap.abs()) * u_leap.sign() + normm = torch.sqrt(merge.abs()) * merge.sign() + sim = F.cosine_similarity(normu, normm).mean() + simsc = 2 * (sim+1) + wm = (simsc*cfg + (4-simsc)*leap) / 4 + return wm + m.set_model_sampler_post_cfg_function(mahiro_normd) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "Mahiro": Mahiro +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", +} diff --git a/nodes.py b/nodes.py index b6e2465f..df412454 100644 --- a/nodes.py +++ b/nodes.py @@ -2147,6 +2147,7 @@ def init_builtin_extra_nodes(): "nodes_torch_compile.py", "nodes_mochi.py", "nodes_slg.py", + "nodes_mahiro.py", "nodes_lt.py", "nodes_hooks.py", ]