mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Use real softmax in split op to fix issue with some images.
This commit is contained in:
parent
7e1e193f39
commit
509c7dfc6d
@ -215,32 +215,24 @@ class AttnBlock(nn.Module):
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
first_op_done = False
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||
first_op_done = True
|
||||
|
||||
torch.exp(s1, out=s1)
|
||||
summed = torch.sum(s1, dim=2, keepdim=True)
|
||||
s1 /= summed
|
||||
s2 = s1.permute(0,2,1)
|
||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||
del s1
|
||||
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except OOM_EXCEPTION as e:
|
||||
if first_op_done == False:
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
else:
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
raise e
|
||||
print("out of memory error, increasing steps and trying again", steps)
|
||||
|
||||
h_ = r1.reshape(b,c,h,w)
|
||||
del r1
|
||||
|
Loading…
Reference in New Issue
Block a user