From 83d969e3975d340ef980db59b07954a67d08ce6f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 May 2024 13:55:49 -0400 Subject: [PATCH] Disable xformers when tracing model. --- comfy/ldm/modules/attention.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 74a2fd99..2ce99d46 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -313,9 +313,19 @@ except: def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads + + disabled_xformers = False + if BROKEN_XFORMERS: if b * heads > 65535: - return attention_pytorch(q, k, v, heads, mask) + disabled_xformers = True + + if not disabled_xformers: + if torch.jit.is_tracing() or torch.jit.is_scripting(): + disabled_xformers = True + + if disabled_xformers: + return attention_pytorch(q, k, v, heads, mask) q, k, v = map( lambda t: t.reshape(b, -1, heads, dim_head),