Slightly smarter batching behaviour.

Try to keep batch sizes more consistent which seems to improve things on
AMD GPUs.
This commit is contained in:
comfyanonymous 2023-02-08 17:09:47 -05:00
parent bbdcf0b737
commit 3fd87cbd21

View File

@ -86,15 +86,21 @@ class CFGDenoiserComplex(torch.nn.Module):
while len(to_run) > 0: while len(to_run) > 0:
first = to_run[0] first = to_run[0]
first_shape = first[0][0].shape first_shape = first[0][0].shape
to_batch = [] to_batch_temp = []
for x in range(len(to_run)): for x in range(len(to_run)):
if to_run[x][0][0].shape == first_shape: if to_run[x][0][0].shape == first_shape:
if to_run[x][0][2].shape == first[0][2].shape: if to_run[x][0][2].shape == first[0][2].shape:
to_batch += [x] to_batch_temp += [x]
if (len(to_batch) * first_shape[0] * first_shape[2] * first_shape[3] >= max_total_area):
break to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area):
to_batch = batch_amount
break
to_batch.reverse()
input_x = [] input_x = []
mult = [] mult = []
c = [] c = []