Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-20 03:13:30 +00:00)
Refactor the attention stuff in the VAE.
commit 23680a9155
parent c8013f73e5
@@ -193,6 +193,52 @@ def slice_attention(q, k, v):
     return r1
 
+def normal_attention(q, k, v):
+    # compute attention
+    b,c,h,w = q.shape
+
+    q = q.reshape(b,c,h*w)
+    q = q.permute(0,2,1)   # b,hw,c
+    k = k.reshape(b,c,h*w) # b,c,hw
+    v = v.reshape(b,c,h*w)
+
+    r1 = slice_attention(q, k, v)
+    h_ = r1.reshape(b,c,h,w)
+    del r1
+    return h_
+
+def xformers_attention(q, k, v):
+    # compute attention
+    B, C, H, W = q.shape
+    q, k, v = map(
+        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
+        (q, k, v),
+    )
+
+    try:
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+        out = out.transpose(1, 2).reshape(B, C, H, W)
+    except NotImplementedError as e:
+        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
+    return out
+
+def pytorch_attention(q, k, v):
+    # compute attention
+    B, C, H, W = q.shape
+    q, k, v = map(
+        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
+        (q, k, v),
+    )
+
+    try:
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+        out = out.transpose(2, 3).reshape(B, C, H, W)
+    except model_management.OOM_EXCEPTION as e:
+        print("scaled_dot_product_attention OOMed: switched to slice attention")
+        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
+    return out
+
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
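All three new helpers share the same contract: q, k and v arrive as VAE feature maps of shape (B, C, H, W) and the attended result comes back in that same layout; the reshapes only translate between the feature-map layout and the token layout each kernel expects. A minimal shape-flow sketch of the pytorch_attention path, assuming torch >= 2.0 is installed (the tensor sizes here are illustrative and not taken from the diff):

import torch

B, C, H, W = 1, 512, 64, 64
q = torch.randn(B, C, H, W)                    # VAE feature map

# same view/transpose as pytorch_attention above
t = q.view(B, 1, C, -1).transpose(2, 3)        # (B, heads=1, H*W tokens, C)
assert t.shape == (B, 1, H * W, C)

out = torch.nn.functional.scaled_dot_product_attention(t, t, t)
out = out.transpose(2, 3).reshape(B, C, H, W)  # back to the feature-map layout
assert out.shape == (B, C, H, W)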
@@ -220,6 +266,16 @@ class AttnBlock(nn.Module):
                                          stride=1,
                                          padding=0)
 
+        if model_management.xformers_enabled_vae():
+            print("Using xformers attention in VAE")
+            self.optimized_attention = xformers_attention
+        elif model_management.pytorch_attention_enabled():
+            print("Using pytorch attention in VAE")
+            self.optimized_attention = pytorch_attention
+        else:
+            print("Using split attention in VAE")
+            self.optimized_attention = normal_attention
+
     def forward(self, x):
         h_ = x
         h_ = self.norm(h_)
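The backend is now picked once per AttnBlock, at construction time, from the model_management flags and stored as self.optimized_attention; the priority order is xformers first, then PyTorch's scaled_dot_product_attention, then the sliced/split fallback. A rough standalone sketch of that dispatch, with stand-in booleans in place of the real model_management checks (the helper name is hypothetical):

def pick_vae_attention(xformers_available, pytorch_sdp_available):
    # mirrors the priority order used in AttnBlock.__init__ above
    if xformers_available:
        return xformers_attention   # xformers.ops.memory_efficient_attention
    elif pytorch_sdp_available:
        return pytorch_attention    # torch scaled_dot_product_attention
    return normal_attention         # sliced/split fallback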
@@ -227,149 +283,15 @@ class AttnBlock(nn.Module):
         k = self.k(h_)
         v = self.v(h_)
 
         # compute attention
-        b,c,h,w = q.shape
-
-        q = q.reshape(b,c,h*w)
-        q = q.permute(0,2,1)   # b,hw,c
-        k = k.reshape(b,c,h*w) # b,c,hw
-        v = v.reshape(b,c,h*w)
-
-        r1 = slice_attention(q, k, v)
-        h_ = r1.reshape(b,c,h,w)
-        del r1
+        h_ = self.optimized_attention(q, k, v)
 
         h_ = self.proj_out(h_)
 
         return x+h_
 
-
-class MemoryEfficientAttnBlock(nn.Module):
-    """
-        Uses xformers efficient implementation,
-        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-        Note: this is a single-head self-attention operation
-    """
-    #
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = comfy.ops.Conv2d(in_channels,
-                                  in_channels,
-                                  kernel_size=1,
-                                  stride=1,
-                                  padding=0)
-        self.k = comfy.ops.Conv2d(in_channels,
-                                  in_channels,
-                                  kernel_size=1,
-                                  stride=1,
-                                  padding=0)
-        self.v = comfy.ops.Conv2d(in_channels,
-                                  in_channels,
-                                  kernel_size=1,
-                                  stride=1,
-                                  padding=0)
-        self.proj_out = comfy.ops.Conv2d(in_channels,
-                                         in_channels,
-                                         kernel_size=1,
-                                         stride=1,
-                                         padding=0)
-        self.attention_op: Optional[Any] = None
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        B, C, H, W = q.shape
-        q, k, v = map(
-            lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
-            (q, k, v),
-        )
-
-        try:
-            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
-            out = out.transpose(1, 2).reshape(B, C, H, W)
-        except NotImplementedError as e:
-            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
-
-        out = self.proj_out(out)
-        return x+out
-
-
-class MemoryEfficientAttnBlockPytorch(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = comfy.ops.Conv2d(in_channels,
-                                  in_channels,
-                                  kernel_size=1,
-                                  stride=1,
-                                  padding=0)
-        self.k = comfy.ops.Conv2d(in_channels,
-                                  in_channels,
-                                  kernel_size=1,
-                                  stride=1,
-                                  padding=0)
-        self.v = comfy.ops.Conv2d(in_channels,
-                                  in_channels,
-                                  kernel_size=1,
-                                  stride=1,
-                                  padding=0)
-        self.proj_out = comfy.ops.Conv2d(in_channels,
-                                         in_channels,
-                                         kernel_size=1,
-                                         stride=1,
-                                         padding=0)
-        self.attention_op: Optional[Any] = None
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        B, C, H, W = q.shape
-        q, k, v = map(
-            lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
-            (q, k, v),
-        )
-
-        try:
-            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-            out = out.transpose(2, 3).reshape(B, C, H, W)
-        except model_management.OOM_EXCEPTION as e:
-            print("scaled_dot_product_attention OOMed: switched to slice attention")
-            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
-
-        out = self.proj_out(out)
-        return x+out
-
-
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
-    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
-    if model_management.xformers_enabled_vae() and attn_type == "vanilla":
-        attn_type = "vanilla-xformers"
-    elif model_management.pytorch_attention_enabled() and attn_type == "vanilla":
-        attn_type = "vanilla-pytorch"
-    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
-    if attn_type == "vanilla":
-        assert attn_kwargs is None
-        return AttnBlock(in_channels)
-    elif attn_type == "vanilla-xformers":
-        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
-        return MemoryEfficientAttnBlock(in_channels)
-    elif attn_type == "vanilla-pytorch":
-        return MemoryEfficientAttnBlockPytorch(in_channels)
-    elif attn_type == "none":
-        return nn.Identity(in_channels)
-    else:
-        raise NotImplementedError()
+    return AttnBlock(in_channels)
 
 
 class Model(nn.Module):
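With the backend selection folded into AttnBlock, make_attn now simply returns an AttnBlock and the block itself decides which attention kernel to run, so the separate MemoryEfficientAttnBlock classes are no longer needed. A hedged usage sketch, assuming the module lives at comfy.ldm.modules.diffusionmodules.model (the import path is an assumption, not stated in the diff):

import torch
# import path assumed; adjust to wherever this model.py sits in your checkout
from comfy.ldm.modules.diffusionmodules.model import make_attn

block = make_attn(512)            # always an AttnBlock now; prints the chosen backend
x = torch.randn(1, 512, 32, 32)   # (B, C, H, W) feature map from the VAE
with torch.no_grad():
    y = block(x)                  # norm -> 1x1 q/k/v convs -> attention -> proj_out -> +x
assert y.shape == x.shape         # the attention block preserves the feature-map layout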