mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Hopefully fix a strange issue with xformers + lowvram.
This commit is contained in:
parent
5f0f97634f
commit
9502ee45c3
@ -394,15 +394,6 @@ class CrossAttention(nn.Module):
|
|||||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
import sys
|
|
||||||
if XFORMERS_IS_AVAILBLE == False:
|
|
||||||
if "--use-split-cross-attention" in sys.argv:
|
|
||||||
print("Using split optimization for cross attention")
|
|
||||||
CrossAttention = CrossAttentionDoggettx
|
|
||||||
else:
|
|
||||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
|
||||||
CrossAttention = CrossAttentionBirchSan
|
|
||||||
|
|
||||||
class MemoryEfficientCrossAttention(nn.Module):
|
class MemoryEfficientCrossAttention(nn.Module):
|
||||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||||
@ -451,23 +442,27 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
import sys
|
||||||
|
if XFORMERS_IS_AVAILBLE == False:
|
||||||
|
if "--use-split-cross-attention" in sys.argv:
|
||||||
|
print("Using split optimization for cross attention")
|
||||||
|
CrossAttention = CrossAttentionDoggettx
|
||||||
|
else:
|
||||||
|
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||||
|
CrossAttention = CrossAttentionBirchSan
|
||||||
|
else:
|
||||||
|
print("Using xformers cross attention")
|
||||||
|
CrossAttention = MemoryEfficientCrossAttention
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
ATTENTION_MODES = {
|
|
||||||
"softmax": CrossAttention, # vanilla attention
|
|
||||||
"softmax-xformers": MemoryEfficientCrossAttention
|
|
||||||
}
|
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||||
disable_self_attn=False):
|
disable_self_attn=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
|
||||||
assert attn_mode in self.ATTENTION_MODES
|
|
||||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
|
||||||
self.disable_self_attn = disable_self_attn
|
self.disable_self_attn = disable_self_attn
|
||||||
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||||
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
||||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||||
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
Loading…
Reference in New Issue
Block a user