diff --git a/nodes.py b/nodes.py index 490389ca..72a73e9f 100644 --- a/nodes.py +++ b/nodes.py @@ -102,6 +102,34 @@ class ConditioningAverage : out.append(n) return (out, ) +class ConditioningConcat: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning_to": ("CONDITIONING",), + "conditioning_from": ("CONDITIONING",), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "concat" + + CATEGORY = "advanced/conditioning" + + def concat(self, conditioning_to, conditioning_from): + out = [] + + if len(conditioning_from) > 1: + print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + + cond_from = conditioning_from[0][0] + + for i in range(len(conditioning_to)): + t1 = conditioning_to[i][0] + tw = torch.cat((t1, cond_from),1) + n = [tw, conditioning_to[i][1].copy()] + out.append(n) + + return (out, ) + class ConditioningSetArea: @classmethod def INPUT_TYPES(s): @@ -1409,6 +1437,7 @@ NODE_CLASS_MAPPINGS = { "SaveLatent": SaveLatent, "ConditioningZeroOut": ConditioningZeroOut, + "ConditioningConcat": ConditioningConcat, } NODE_DISPLAY_NAME_MAPPINGS = {