Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-19 19:03:51 +00:00)
Refactor attention upcasting code part 1.
parent 2de3b69b30
commit b0ab31d06c
@@ -22,9 +22,9 @@ ops = comfy.ops.disable_weight_init
 # CrossAttn precision handling
 if args.dont_upcast_attention:
     logging.info("disabling upcasting of attention")
-    _ATTN_PRECISION = "fp16"
+    _ATTN_PRECISION = None
 else:
-    _ATTN_PRECISION = "fp32"
+    _ATTN_PRECISION = torch.float32


 def exists(val):
@@ -85,7 +85,7 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

-def attention_basic(q, k, v, heads, mask=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
@@ -101,7 +101,7 @@ def attention_basic(q, k, v, heads, mask=None):
     )

     # force cast to fp32 to avoid overflowing
-    if _ATTN_PRECISION =="fp32":
+    if attn_precision == torch.float32:
         sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
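
The comment kept in this hunk ("force cast to fp32 to avoid overflowing") is the reason the precision argument exists: the pre-softmax logits are a dot product over dim_head channels, and fp16 saturates near 65504, so the raw einsum result can overflow even when the scaled logit would not. A small standalone illustration of that bound (the magnitudes are made up for the example, not taken from ComfyUI):

import torch

fp16_max = torch.finfo(torch.float16).max   # 65504.0
dim_head = 64
scale = dim_head ** -0.5

# Illustrative activation magnitude of 40 per channel for both q and k:
dot = 40.0 * 40.0 * dim_head                 # 102400.0, the raw q.k dot product
print(dot > fp16_max)                        # True: the fp16 einsum result would already be inf
print(dot * scale > fp16_max)                # False: the scaled logit would have fit

Running the einsum in fp32, as the upcast branch does, keeps that intermediate finite until softmax normalizes it.
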
@@ -135,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None):
     return out


-def attention_sub_quad(query, key, value, heads, mask=None):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
     b, _, dim_head = query.shape
     dim_head //= heads

@@ -146,7 +146,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

     dtype = query.dtype
-    upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
     if upcast_attention:
         bytes_per_token = torch.finfo(torch.float32).bits//8
     else:
@@ -195,7 +195,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states

-def attention_split(q, k, v, heads, mask=None):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
@@ -214,10 +214,12 @@ def attention_split(q, k, v, heads, mask=None):

     mem_free_total = model_management.get_free_memory(q.device)

-    if _ATTN_PRECISION =="fp32":
+    if attn_precision == torch.float32:
         element_size = 4
+        upcast = True
     else:
         element_size = q.element_size()
+        upcast = False

     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
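
The new upcast flag is set right where the memory estimate is made, so the fp32 path also budgets 4 bytes per score element. A rough, self-contained sketch of that sizing idea (plan_attention_slices, the simple power-of-two split, and the free-memory figure are illustrative stand-ins, not ComfyUI's exact code):

import math
import torch

def plan_attention_slices(q, k, mem_free_total, attn_precision=None):
    # Hypothetical helper: mirrors the element_size / upcast choice in the hunk above.
    if attn_precision == torch.float32:
        element_size = 4                   # scores will be materialized in fp32
        upcast = True
    else:
        element_size = q.element_size()    # scores stay in the input dtype
        upcast = False

    # Full score tensor: (batch*heads, q_tokens, k_tokens) entries of element_size bytes.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    steps = 1
    if tensor_size > mem_free_total:
        # Split the query dimension into a power-of-two number of chunks.
        steps = 2 ** math.ceil(math.log2(tensor_size / mem_free_total))
    return steps, upcast

q = torch.empty(16, 4096, 64, dtype=torch.float16)
k = torch.empty(16, 4096, 64, dtype=torch.float16)
print(plan_attention_slices(q, k, mem_free_total=512 * 1024 ** 2, attn_precision=torch.float32))  # (2, True)
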
@@ -251,7 +253,7 @@ def attention_split(q, k, v, heads, mask=None):
     slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
     for i in range(0, q.shape[1], slice_size):
         end = i + slice_size
-        if _ATTN_PRECISION =="fp32":
+        if upcast:
             with torch.autocast(enabled=False, device_type = 'cuda'):
                 s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
         else:
@@ -302,7 +304,7 @@ try:
 except:
     pass

-def attention_xformers(q, k, v, heads, mask=None):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     if BROKEN_XFORMERS:
@@ -334,7 +336,7 @@ def attention_xformers(q, k, v, heads, mask=None):
     )
     return out

-def attention_pytorch(q, k, v, heads, mask=None):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     q, k, v = map(
@@ -409,9 +411,9 @@ class CrossAttention(nn.Module):
         v = self.to_v(context)

         if mask is None:
-            out = optimized_attention(q, k, v, self.heads)
+            out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION)
         else:
-            out = optimized_attention_masked(q, k, v, self.heads, mask)
+            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION)
         return self.to_out(out)


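
Taken together, the hunks replace the module-level "fp16"/"fp32" string with a dtype (or None) that each attention backend receives as a keyword argument. A condensed, runnable sketch of that wiring, with a simplified attention body standing in for the real implementations above (an illustration of the pattern, not ComfyUI's code):

import torch

def make_attn_precision(dont_upcast_attention: bool):
    # Old: _ATTN_PRECISION = "fp16" / "fp32".  New: None (keep input dtype) / torch.float32 (upcast).
    return None if dont_upcast_attention else torch.float32

def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
    # Simplified stand-in for the refactored backends: the precision arrives
    # as an argument instead of being read from a module-level global.
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5
    q, k, v = [t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
    if attn_precision == torch.float32:
        sim = (q.float() @ k.float().transpose(-1, -2)) * scale   # upcast the logits
    else:
        sim = (q @ k.transpose(-1, -2)) * scale
    out = sim.softmax(dim=-1).to(v.dtype) @ v
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

attn_precision = make_attn_precision(dont_upcast_attention=False)
q = k = v = torch.randn(1, 77, 320)
print(attention_basic(q, k, v, heads=8, attn_precision=attn_precision).shape)  # torch.Size([1, 77, 320])
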