Use real softmax in split op to fix issue with some images.

This commit is contained in:
comfyanonymous 2023-02-10 02:46:50 -05:00
parent 7e1e193f39
commit 509c7dfc6d

View File

@ -215,32 +215,24 @@ class AttnBlock(nn.Module):
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
first_op_done = False
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
first_op_done = True
torch.exp(s1, out=s1)
summed = torch.sum(s1, dim=2, keepdim=True)
s1 /= summed
s2 = s1.permute(0,2,1)
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
del s1
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except OOM_EXCEPTION as e:
if first_op_done == False:
steps *= 2
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
raise e
h_ = r1.reshape(b,c,h,w)
del r1