diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 05bcae6c..5ee7d5ae 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -394,15 +394,6 @@ class CrossAttention(nn.Module):
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)
 
-import sys
-if XFORMERS_IS_AVAILBLE == False:
-    if "--use-split-cross-attention" in sys.argv:
-        print("Using split optimization for cross attention")
-        CrossAttention = CrossAttentionDoggettx
-    else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-        CrossAttention = CrossAttentionBirchSan
-
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
@@ -451,23 +442,27 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
         return self.to_out(out)
 
+import sys
+if XFORMERS_IS_AVAILBLE == False:
+    if "--use-split-cross-attention" in sys.argv:
+        print("Using split optimization for cross attention")
+        CrossAttention = CrossAttentionDoggettx
+    else:
+        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        CrossAttention = CrossAttentionBirchSan
+else:
+    print("Using xformers cross attention")
+    CrossAttention = MemoryEfficientCrossAttention
 
 class BasicTransformerBlock(nn.Module):
-    ATTENTION_MODES = {
-        "softmax": CrossAttention,  # vanilla attention
-        "softmax-xformers": MemoryEfficientCrossAttention
-    }
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                  disable_self_attn=False):
         super().__init__()
-        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
-        assert attn_mode in self.ATTENTION_MODES
-        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn
-        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                               context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
-        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                               heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)