diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py new file mode 100644 index 00000000..8fcdfba7 --- /dev/null +++ b/comfy_extras/nodes_mahiro.py @@ -0,0 +1,41 @@ +import torch +import torch.nn.functional as F + +class Mahiro: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt." + def patch(self, model): + m = model.clone() + def mahiro_normd(args): + scale: float = args['cond_scale'] + cond_p: torch.Tensor = args['cond_denoised'] + uncond_p: torch.Tensor = args['uncond_denoised'] + #naive leap + leap = cond_p * scale + #sim with uncond leap + u_leap = uncond_p * scale + cfg = args["denoised"] + merge = (leap + cfg) / 2 + normu = torch.sqrt(u_leap.abs()) * u_leap.sign() + normm = torch.sqrt(merge.abs()) * merge.sign() + sim = F.cosine_similarity(normu, normm).mean() + simsc = 2 * (sim+1) + wm = (simsc*cfg + (4-simsc)*leap) / 4 + return wm + m.set_model_sampler_post_cfg_function(mahiro_normd) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "Mahiro": Mahiro +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Mahiro": "Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", +} diff --git a/nodes.py b/nodes.py index b6e2465f..df412454 100644 --- a/nodes.py +++ b/nodes.py @@ -2147,6 +2147,7 @@ def init_builtin_extra_nodes(): "nodes_torch_compile.py", "nodes_mochi.py", "nodes_slg.py", + "nodes_mahiro.py", "nodes_lt.py", "nodes_hooks.py", ]