diff --git a/nodes.py b/nodes.py index 7ed7a8e4..027bf55d 100644 --- a/nodes.py +++ b/nodes.py @@ -9,7 +9,7 @@ import math import time import random -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch @@ -1410,17 +1410,30 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) + img = Image.open(image_path) + output_images = [] + output_masks = [] + for i in ImageSequence.Iterator(img): + i = ImageOps.exif_transpose(i) + image = i.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + if len(output_images) > 1: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (image, mask.unsqueeze(0)) + output_image = output_images[0] + output_mask = output_masks[0] + + return (output_image, output_mask) @classmethod def IS_CHANGED(s, image):