mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Use real softmax in split op to fix issue with some images.
This commit is contained in:
parent
7e1e193f39
commit
509c7dfc6d
@ -215,32 +215,24 @@ class AttnBlock(nn.Module):
|
|||||||
if mem_required > mem_free_total:
|
if mem_required > mem_free_total:
|
||||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
first_op_done = False
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||||
first_op_done = True
|
|
||||||
|
|
||||||
torch.exp(s1, out=s1)
|
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||||
summed = torch.sum(s1, dim=2, keepdim=True)
|
|
||||||
s1 /= summed
|
|
||||||
s2 = s1.permute(0,2,1)
|
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||||
del s2
|
del s2
|
||||||
break
|
break
|
||||||
except OOM_EXCEPTION as e:
|
except OOM_EXCEPTION as e:
|
||||||
if first_op_done == False:
|
steps *= 2
|
||||||
steps *= 2
|
if steps > 128:
|
||||||
if steps > 128:
|
|
||||||
raise e
|
|
||||||
print("out of memory error, increasing steps and trying again", steps)
|
|
||||||
else:
|
|
||||||
raise e
|
raise e
|
||||||
|
print("out of memory error, increasing steps and trying again", steps)
|
||||||
|
|
||||||
h_ = r1.reshape(b,c,h,w)
|
h_ = r1.reshape(b,c,h,w)
|
||||||
del r1
|
del r1
|
||||||
|
Loading…
Reference in New Issue
Block a user