From c6951548cfec64c28082e6560c69c59e32729c9c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 7 Jan 2024 13:52:08 -0500 Subject: [PATCH] Update optimized_attention_for_device function for new functions that support masked attention. --- comfy/clip_model.py | 2 +- comfy/ldm/modules/attention.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7397b7a2..09e7bbca 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -57,7 +57,7 @@ class CLIPEncoder(torch.nn.Module): self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) def forward(self, x, mask=None, intermediate_output=None): - optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) if intermediate_output is not None: if intermediate_output < 0: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 8015a307..309240d5 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -333,7 +333,6 @@ def attention_pytorch(q, k, v, heads, mask=None): optimized_attention = attention_basic -optimized_attention_masked = attention_basic if model_management.xformers_enabled(): print("Using xformers cross attention") @@ -349,15 +348,15 @@ else: print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -if model_management.pytorch_attention_enabled(): - optimized_attention_masked = attention_pytorch +optimized_attention_masked = optimized_attention + +def optimized_attention_for_device(device, mask=False, small_input=False): + if small_input and model_management.pytorch_attention_enabled(): + return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + + if device == torch.device("cpu"): + return attention_sub_quad -def optimized_attention_for_device(device, mask=False): - if device == torch.device("cpu"): #TODO - if model_management.pytorch_attention_enabled(): - return attention_pytorch - else: - return attention_basic if mask: return optimized_attention_masked