LoadImage now loads all the frames from animated images as a batch.

This commit is contained in:
comfyanonymous 2023-12-20 16:39:09 -05:00
parent 5f54614e7f
commit a1e1c69f7d

View File

@ -9,7 +9,7 @@ import math
import time import time
import random import random
from PIL import Image, ImageOps from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
@ -1410,7 +1410,10 @@ class LoadImage:
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path) img = Image.open(image_path)
output_images = []
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i) i = ImageOps.exif_transpose(i)
image = i.convert("RGB") image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
@ -1420,7 +1423,17 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (image, mask.unsqueeze(0)) output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):