diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index f8c959cc..6f0b41dc 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -215,32 +215,24 @@ class AttnBlock(nn.Module): if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - first_op_done = False while True: try: slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size s1 = torch.bmm(q[:, i:end], k) * scale - first_op_done = True - torch.exp(s1, out=s1) - summed = torch.sum(s1, dim=2, keepdim=True) - s1 /= summed - s2 = s1.permute(0,2,1) + s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) del s1 r1[:, :, i:end] = torch.bmm(v, s2) del s2 break except OOM_EXCEPTION as e: - if first_op_done == False: - steps *= 2 - if steps > 128: - raise e - print("out of memory error, increasing steps and trying again", steps) - else: + steps *= 2 + if steps > 128: raise e + print("out of memory error, increasing steps and trying again", steps) h_ = r1.reshape(b,c,h,w) del r1