diff --git a/nodes.py b/nodes.py index 3c096009..f10515f8 100644 --- a/nodes.py +++ b/nodes.py @@ -148,6 +148,25 @@ class ConditioningSetMask: c.append(n) return (c, ) +class ConditioningZeroOut: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", )}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "zero_out" + + CATEGORY = "advanced/conditioning" + + def zero_out(self, conditioning): + c = [] + for t in conditioning: + d = t[1].copy() + if "pooled_output" in d: + d["pooled_output"] = torch.zeros_like(d["pooled_output"]) + n = [torch.zeros_like(t[0]), d] + c.append(n) + return (c, ) + class VAEDecode: @classmethod def INPUT_TYPES(s): @@ -1350,6 +1369,8 @@ NODE_CLASS_MAPPINGS = { "LoadLatent": LoadLatent, "SaveLatent": SaveLatent, + + "ConditioningZeroOut": ConditioningZeroOut, } NODE_DISPLAY_NAME_MAPPINGS = {