From 37e5390f5ff01ae367ac37d62377bbedff2a68da Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 18 Dec 2024 01:56:10 -0500
Subject: [PATCH] Add: --use-sage-attention to enable SageAttention.

You need to have the library installed first.
---
 comfy/cli_args.py              |  1 +
 comfy/ldm/modules/attention.py | 47 ++++++++++++++++++++++++++++++----
 comfy/model_management.py      |  2 ++
 3 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 847f35ab..4c654501 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -104,6 +104,7 @@ attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
 
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index e60d1ab2..0d54e6be 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -15,6 +15,9 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
+if model_management.sage_attention_enabled():
+    from sageattention import sageattn
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -447,20 +450,54 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
 
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+        tensor_layout="HND"
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head),
+            (q, k, v),
+        )
+        tensor_layout="NHD"
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+    if tensor_layout == "HND":
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    else:
+        out = out.reshape(b, -1, heads * dim_head)
+    return out
+
+
 optimized_attention = attention_basic
 
-if model_management.xformers_enabled():
-    logging.info("Using xformers cross attention")
+if model_management.sage_attention_enabled():
+    logging.info("Using sage attention")
+    optimized_attention = attention_sage
+elif model_management.xformers_enabled():
+    logging.info("Using xformers attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.info("Using pytorch cross attention")
+    logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.info("Using split optimization for cross attention")
+        logging.info("Using split optimization for attention")
         optimized_attention = attention_split
     else:
-        logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
 optimized_attention_masked = optimized_attention
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 177c7998..f6ca252e 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -837,6 +837,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
+def sage_attention_enabled():
+    return args.use_sage_attention
 def xformers_enabled():
     global directml_enabled