LoadImage now loads all the frames from animated images as a batch.

This commit is contained in:
comfyanonymous 2023-12-20 16:39:09 -05:00
parent 5f54614e7f
commit a1e1c69f7d

View File

@ -9,7 +9,7 @@ import math
import time
import random
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
@ -1410,7 +1410,10 @@ class LoadImage:
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
img = Image.open(image_path)
output_images = []
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
@ -1420,7 +1423,17 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (image, mask.unsqueeze(0))
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):