Speedup on some models by not upcasting bfloat16 to float32 on mac.

This commit is contained in:
comfyanonymous 2025-02-24 05:41:07 -05:00
parent 4553891bbd
commit 96d891cb94
2 changed files with 8 additions and 7 deletions

View File

@ -30,11 +30,12 @@ ops = comfy.ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype() FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
def get_attn_precision(attn_precision): def get_attn_precision(attn_precision, current_dtype):
if args.dont_upcast_attention: if args.dont_upcast_attention:
return None return None
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
return FORCE_UPCAST_ATTENTION_DTYPE if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
return attn_precision return attn_precision
def exists(val): def exists(val):
@ -81,7 +82,7 @@ def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision) attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
@ -150,7 +151,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision) attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape: if skip_reshape:
b, _, _, dim_head = query.shape b, _, _, dim_head = query.shape
@ -220,7 +221,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
return hidden_states return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
attn_precision = get_attn_precision(attn_precision) attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape

View File

@ -954,7 +954,7 @@ def force_upcast_attention_dtype():
upcast = True upcast = True
if upcast: if upcast:
return torch.float32 return {torch.float16: torch.float32}
else: else:
return None return None