From 11200de9700aed41011ed865a164f43d27b62d82 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 8 Aug 2024 20:07:09 -0400
Subject: [PATCH] Cleaner code.

---
 comfy/ldm/flux/layers.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 4a0bd40c..643e7c67 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -8,6 +8,7 @@ from torch import Tensor, nn
 from .math import attention, rope
 import comfy.ops
 
+
 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
@@ -174,20 +175,19 @@ class DoubleStreamBlock(nn.Module):
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2), pe=pe)
 
-        attn = attention(q, k, v, pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img += img_mod1.gate * self.img_attn.proj(img_attn)
+        img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
             txt = txt.clip(-65504, 65504)
@@ -243,7 +243,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x = x + mod.gate * output
+        x += mod.gate * output
         if x.dtype == torch.float16:
             x = x.clip(-65504, 65504)
         return x