Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-03-12 22:02:14 +00:00)
Add: --use-sage-attention to enable SageAttention.
You need to have the library installed first.
parent a4f59bc65e
commit 37e5390f5f
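
The commit message notes that the SageAttention library must already be installed. Below is a minimal pre-flight sketch (not part of the commit) that confirms the package is importable before ComfyUI is launched with the new flag; only the sageattention import mirrors the diff below, the rest is illustrative.

# Illustrative check, not part of the commit: confirm the sageattention package
# that --use-sage-attention relies on can be imported in the current environment.
try:
    from sageattention import sageattn  # same import the patched attention module performs
    print("SageAttention is available; ComfyUI can be launched with --use-sage-attention.")
except ImportError:
    print("SageAttention is not installed; install the sageattention package first.")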
@@ -104,6 +104,7 @@ attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
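
Because the new flag is registered on the same mutually exclusive group as the other attention options, argparse rejects combining them. A small standalone sketch of that behaviour (illustrative only; the real parser and the remaining flags live in ComfyUI's CLI argument module):

# Illustrative sketch of the mutually exclusive group shown in the hunk above.
import argparse

parser = argparse.ArgumentParser()
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true")
attn_group.add_argument("--use-sage-attention", action="store_true")

args = parser.parse_args(["--use-sage-attention"])
print(args.use_sage_attention)  # True

# Supplying both flags at once makes argparse exit with an error along the lines of
# "argument --use-sage-attention: not allowed with argument --use-pytorch-cross-attention":
# parser.parse_args(["--use-pytorch-cross-attention", "--use-sage-attention"])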
@@ -15,6 +15,9 @@ if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
+if model_management.sage_attention_enabled():
+    from sageattention import sageattn
+
 from comfy.cli_args import args
 import comfy.ops
 ops = comfy.ops.disable_weight_init
@@ -447,20 +450,54 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out
 
+
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+        tensor_layout="HND"
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head),
+            (q, k, v),
+        )
+        tensor_layout="NHD"
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+    if tensor_layout == "HND":
+        out = (
+            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        )
+    else:
+        out = out.reshape(b, -1, heads * dim_head)
+    return out
+
+
 optimized_attention = attention_basic
 
-if model_management.xformers_enabled():
-    logging.info("Using xformers cross attention")
+if model_management.sage_attention_enabled():
+    logging.info("Using sage attention")
+    optimized_attention = attention_sage
+elif model_management.xformers_enabled():
+    logging.info("Using xformers attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.info("Using pytorch cross attention")
+    logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.info("Using split optimization for cross attention")
+        logging.info("Using split optimization for attention")
         optimized_attention = attention_split
     else:
-        logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
 optimized_attention_masked = optimized_attention
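
For reference, here is a minimal sketch (not part of the commit) of how the underlying sageattn kernel is called with the "NHD" layout that attention_sage builds when skip_reshape is False. The CUDA device, fp16 dtype, and tensor shapes are assumptions chosen for illustration, and the mask handling from the function above is omitted.

# Illustrative only: call sageattn with the (batch, seq, heads, dim_head) "NHD"
# layout that attention_sage constructs in its non-skip_reshape path.
# Assumes a CUDA GPU and the sageattention package are available.
import torch
from sageattention import sageattn

b, seq, heads, dim_head = 2, 1024, 8, 64
q = torch.randn(b, seq, heads, dim_head, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = sageattn(q, k, v, is_causal=False, tensor_layout="NHD")
# attention_sage then flattens the result back to (batch, seq, heads * dim_head):
out = out.reshape(b, -1, heads * dim_head)
print(out.shape)  # torch.Size([2, 1024, 512])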
@@ -837,6 +837,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
+def sage_attention_enabled():
+    return args.use_sage_attention
 
 def xformers_enabled():
     global directml_enabled