Update optimized_attention_for_device function for new functions that support masked attention.
This commit is contained in:
comfyanonymous 2024-01-07 13:52:08 -05:00
parent aaa9017302
commit c6951548cf
2 changed files with 9 additions and 10 deletions

View File

@@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module):
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None): def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None: if intermediate_output is not None:
if intermediate_output < 0: if intermediate_output < 0:

View File

@@ -333,7 +333,6 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled(): if model_management.xformers_enabled():
print("Using xformers cross attention") print("Using xformers cross attention")
@@ -349,15 +348,15 @@ else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled(): optimized_attention_masked = optimized_attention
optimized_attention_masked = attention_pytorch
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input and model_management.pytorch_attention_enabled():
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
if device == torch.device("cpu"):
return attention_sub_quad
def optimized_attention_for_device(device, mask=False):
if device == torch.device("cpu"): #TODO
if model_management.pytorch_attention_enabled():
return attention_pytorch
else:
return attention_basic
if mask: if mask:
return optimized_attention_masked return optimized_attention_masked