Remove unnecessary code.

This commit is contained in:
comfyanonymous 2024-05-18 09:36:26 -04:00
parent 1c4af5918a
commit 98f828fad9
2 changed files with 2 additions and 10 deletions

View File

@ -318,11 +318,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
@ -335,10 +331,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
out.reshape(b, -1, heads * dim_head)
)
return out

View File

@ -3,7 +3,6 @@ import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from typing import Optional, Any
import logging