diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index f6d16b95..1d2daa9d 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -20,6 +20,11 @@ except:
 
 import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
+
 def exists(val):
     return val is not None
@@ -316,7 +321,7 @@ class CrossAttentionDoggettx(nn.Module):
                     r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                     del s2
                 break
-            except torch.cuda.OutOfMemoryError as e:
+            except OOM_EXCEPTION as e:
                 if first_op_done == False:
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 62cd0c51..f8c959cc 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -16,6 +16,10 @@ except:
     XFORMERS_IS_AVAILBLE = False
     print("No module 'xformers'. Proceeding without it.")
 
+try:
+    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
+except:
+    OOM_EXCEPTION = Exception
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """
@@ -229,7 +233,7 @@ class AttnBlock(nn.Module):
                     r1[:, :, i:end] = torch.bmm(v, s2)
                     del s2
                 break
-            except torch.cuda.OutOfMemoryError as e:
+            except OOM_EXCEPTION as e:
                 if first_op_done == False:
                     steps *= 2
                     if steps > 128:
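
The patch aliases torch.cuda.OutOfMemoryError to a module-level OOM_EXCEPTION at import time, so the out-of-memory retry loops in both files also run on torch builds that do not expose that exception class (likely older PyTorch releases or non-CUDA builds); in that case the bare except falls back to catching the generic Exception. A minimal, self-contained sketch of the same fallback pattern follows; retry_on_oom and its parameters are illustrative helpers, not part of this patch:

    import torch

    # Prefer the CUDA-specific OOM exception when the installed torch defines it;
    # otherwise fall back to the generic Exception, as in the patch above.
    try:
        OOM_EXCEPTION = torch.cuda.OutOfMemoryError
    except:  # torch build without torch.cuda.OutOfMemoryError (assumption)
        OOM_EXCEPTION = Exception

    def retry_on_oom(fn, attempts=3):
        # Hypothetical helper: rerun fn after freeing cached CUDA memory when it
        # runs out of memory, mirroring the retry loops in the patched files.
        for _ in range(attempts - 1):
            try:
                return fn()
            except OOM_EXCEPTION:
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
        return fn()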