Use real softmax in split op to fix issue with some images.

This commit is contained in:
comfyanonymous 2023-02-10 02:46:50 -05:00
parent 7e1e193f39
commit 509c7dfc6d

View File

@ -215,32 +215,24 @@ class AttnBlock(nn.Module):
if mem_required > mem_free_total: if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
first_op_done = False
while True: while True:
try: try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale s1 = torch.bmm(q[:, i:end], k) * scale
first_op_done = True
torch.exp(s1, out=s1) s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
summed = torch.sum(s1, dim=2, keepdim=True)
s1 /= summed
s2 = s1.permute(0,2,1)
del s1 del s1
r1[:, :, i:end] = torch.bmm(v, s2) r1[:, :, i:end] = torch.bmm(v, s2)
del s2 del s2
break break
except OOM_EXCEPTION as e: except OOM_EXCEPTION as e:
if first_op_done == False: steps *= 2
steps *= 2 if steps > 128:
if steps > 128:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
raise e raise e
print("out of memory error, increasing steps and trying again", steps)
h_ = r1.reshape(b,c,h,w) h_ = r1.reshape(b,c,h,w)
del r1 del r1