mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
attention_basic now works with hypertile.
This commit is contained in:
parent
e6bc42df46
commit
8b65f5de54
@ -95,9 +95,19 @@ def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
h = heads
|
||||
scale = (q.shape[-1] // heads) ** -0.5
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
@ -119,7 +129,12 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user