ComfyUI (https://github.com/comfyanonymous/ComfyUI.git)
Speedup on some models by not upcasting bfloat16 to float32 on mac.
commit 96d891cb94 (parent 4553891bbd)
comfy/ldm/modules/attention.py
@@ -30,11 +30,12 @@ ops = comfy.ops.disable_weight_init

 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()

-def get_attn_precision(attn_precision):
+def get_attn_precision(attn_precision, current_dtype):
     if args.dont_upcast_attention:
         return None
-    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
-        return FORCE_UPCAST_ATTENTION_DTYPE
+
+    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
+        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
     return attn_precision

 def exists(val):
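The heart of the change is in get_attn_precision: FORCE_UPCAST_ATTENTION_DTYPE used to be a single dtype applied to every model, and is now a dtype-to-dtype mapping consulted with the dtype the attention is actually running in. A minimal standalone sketch of the new lookup, with the module globals and CLI flag stubbed out for illustration (in ComfyUI they come from comfy.model_management and comfy.cli_args):

    import torch

    # Stubs standing in for ComfyUI's module state (illustrative values only).
    FORCE_UPCAST_ATTENTION_DTYPE = {torch.float16: torch.float32}
    DONT_UPCAST_ATTENTION = False  # stands in for args.dont_upcast_attention

    def get_attn_precision(attn_precision, current_dtype):
        if DONT_UPCAST_ATTENTION:
            return None

        # Upcast only if the running dtype has an entry in the mapping.
        if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
            return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
        return attn_precision

    print(get_attn_precision(None, torch.float16))   # torch.float32 -> still upcast
    print(get_attn_precision(None, torch.bfloat16))  # None -> bfloat16 left alone

Under the old code a bfloat16 model on mac was forced through float32 here as well; now only float16 hits the upcast path, which is where the advertised speedup comes from.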
@@ -81,7 +82,7 @@ def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

 def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
         b, _, _, dim_head = q.shape
@@ -150,7 +151,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape


 def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, query.dtype)

     if skip_reshape:
         b, _, _, dim_head = query.shape
@@ -220,7 +221,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     return hidden_states

 def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
-    attn_precision = get_attn_precision(attn_precision)
+    attn_precision = get_attn_precision(attn_precision, q.dtype)

     if skip_reshape:
         b, _, _, dim_head = q.shape
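The second through fourth hunks are the matching call-site updates: attention_basic, attention_sub_quad and attention_split each forward the dtype of their query tensor, so the upcast decision keys on the dtype the math will actually run in. A sketch of the pattern follows; attention_stub is a made-up toy that reuses the get_attn_precision sketch above, not a ComfyUI function:

    import torch

    def attention_stub(q, k, v, attn_precision=None):
        # Same call-site pattern as attention_basic / attention_sub_quad /
        # attention_split after this commit: pass the query dtype along.
        attn_precision = get_attn_precision(attn_precision, q.dtype)  # sketch above
        if attn_precision is not None:
            q = q.to(attn_precision)  # upcast only when the mapping asked for it
            k = k.to(attn_precision)
        scores = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        return scores.to(v.dtype) @ v

    q = k = v = torch.randn(2, 8, 16, 64, dtype=torch.bfloat16)
    print(attention_stub(q, k, v).dtype)  # torch.bfloat16 -- no float32 detour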
comfy/model_management.py
@@ -954,7 +954,7 @@ def force_upcast_attention_dtype():
             upcast = True

     if upcast:
-        return torch.float32
+        return {torch.float16: torch.float32}
     else:
         return None

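On the model_management side, the macOS workaround (forcing attention upcast to dodge a black-image bug) now advertises a mapping rather than a bare dtype, which is what lets bfloat16 opt out. A before/after comparison, with both functions written as illustrative stand-ins for the real one and upcast assumed True (the mac case):

    import torch

    def force_upcast_attention_dtype_old(upcast=True):
        return torch.float32 if upcast else None  # old: one dtype for everything

    def force_upcast_attention_dtype_new(upcast=True):
        return {torch.float16: torch.float32} if upcast else None  # new: float16 only

    for dtype in (torch.float16, torch.bfloat16):
        forced = force_upcast_attention_dtype_new()
        upcast_to = forced.get(dtype) if forced is not None else None
        print(dtype, "->", upcast_to)
    # torch.float16  -> torch.float32  (workaround still applies)
    # torch.bfloat16 -> None           (runs natively: the speedup on mac)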