Same thing but for the other places where it's used.

This commit is contained in:
comfyanonymous 2023-02-09 12:43:29 -05:00
parent df40d4f3bf
commit 773cdabfce
2 changed files with 11 additions and 2 deletions

View File

@ -20,6 +20,11 @@ except:
import os import os
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
def exists(val): def exists(val):
return val is not None return val is not None
@ -316,7 +321,7 @@ class CrossAttentionDoggettx(nn.Module):
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2 del s2
break break
except torch.cuda.OutOfMemoryError as e: except OOM_EXCEPTION as e:
if first_op_done == False: if first_op_done == False:
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()

View File

@ -16,6 +16,10 @@ except:
XFORMERS_IS_AVAILBLE = False XFORMERS_IS_AVAILBLE = False
print("No module 'xformers'. Proceeding without it.") print("No module 'xformers'. Proceeding without it.")
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
OOM_EXCEPTION = Exception
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim):
""" """
@ -229,7 +233,7 @@ class AttnBlock(nn.Module):
r1[:, :, i:end] = torch.bmm(v, s2) r1[:, :, i:end] = torch.bmm(v, s2)
del s2 del s2
break break
except torch.cuda.OutOfMemoryError as e: except OOM_EXCEPTION as e:
if first_op_done == False: if first_op_done == False:
steps *= 2 steps *= 2
if steps > 128: if steps > 128: