Remove unnecessary code.

This commit is contained in:
comfyanonymous 2024-05-18 09:36:26 -04:00
parent 1c4af5918a
commit 98f828fad9
2 changed files with 2 additions and 10 deletions

View File

@ -318,11 +318,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
return attention_pytorch(q, k, v, heads, mask) return attention_pytorch(q, k, v, heads, mask)
q, k, v = map( q, k, v = map(
lambda t: t.unsqueeze(3) lambda t: t.reshape(b, -1, heads, dim_head),
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v), (q, k, v),
) )
@ -335,10 +331,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = ( out = (
out.unsqueeze(0) out.reshape(b, -1, heads * dim_head)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
) )
return out return out

View File

@ -3,7 +3,6 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from einops import rearrange
from typing import Optional, Any from typing import Optional, Any
import logging import logging