Add a RebatchImages node.

This commit is contained in:
comfyanonymous 2023-12-20 16:22:18 -05:00
parent e82942cc29
commit 5f54614e7f

View File

@ -99,10 +99,40 @@ class LatentRebatch:
return (output_list,) return (output_list,)
class ImageRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0]
output_list = []
all_images = []
for img in images:
for i in range(img.shape[0]):
all_images.append(img[i:i+1])
for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch, "RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents", "RebatchLatents": "Rebatch Latents",
} "RebatchImages": "Rebatch Images",
}