mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-12 22:02:14 +00:00
ConditioningAverage now also averages the pooled output.
This commit is contained in:
parent
d94ddd8548
commit
3a09fac835
10
nodes.py
10
nodes.py
@ -82,15 +82,23 @@ class ConditioningAverage :
|
|||||||
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
||||||
|
|
||||||
cond_from = conditioning_from[0][0]
|
cond_from = conditioning_from[0][0]
|
||||||
|
pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
|
||||||
|
|
||||||
for i in range(len(conditioning_to)):
|
for i in range(len(conditioning_to)):
|
||||||
t1 = conditioning_to[i][0]
|
t1 = conditioning_to[i][0]
|
||||||
|
pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
|
||||||
t0 = cond_from[:,:t1.shape[1]]
|
t0 = cond_from[:,:t1.shape[1]]
|
||||||
if t0.shape[1] < t1.shape[1]:
|
if t0.shape[1] < t1.shape[1]:
|
||||||
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
|
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
|
||||||
|
|
||||||
tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
|
tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
|
||||||
n = [tw, conditioning_to[i][1].copy()]
|
t_to = conditioning_to[i][1].copy()
|
||||||
|
if pooled_output_from is not None and pooled_output_to is not None:
|
||||||
|
t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
|
||||||
|
elif pooled_output_from is not None:
|
||||||
|
t_to["pooled_output"] = pooled_output_from
|
||||||
|
|
||||||
|
n = [tw, t_to]
|
||||||
out.append(n)
|
out.append(n)
|
||||||
return (out, )
|
return (out, )
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user