Add a LoadImageMask node to load one colour channel in an image as a mask.

This commit is contained in:
comfyanonymous 2023-02-15 17:39:42 -05:00
parent d75003001a
commit e87a8669b6

View File

@ -410,11 +410,8 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
if "noise_mask" in latent:
noise_mask = latent['noise_mask']
print(noise_mask.shape, noise.shape)
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.floor()
noise_mask = torch.ones_like(noise_mask) - noise_mask
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0])
noise_mask = noise_mask.to(device)
@ -581,10 +578,11 @@ class LoadImage:
FUNCTION = "load_image"
def load_image(self, image):
image_path = os.path.join(self.input_dir, image)
image = Image.open(image_path).convert("RGB")
i = Image.open(image_path)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image[None])[None,]
return image
image = torch.from_numpy(image)[None,]
return (image,)
@classmethod
def IS_CHANGED(s, image):
@ -594,6 +592,41 @@ class LoadImage:
m.update(f.read())
return m.digest().hex()
class LoadImageMask:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
def INPUT_TYPES(s):
return {"required":
{"image": (os.listdir(s.input_dir), ),
"channel": (["alpha", "red", "green", "blue"], ),}
}
CATEGORY = "image"
RETURN_TYPES = ("MASK",)
FUNCTION = "load_image"
def load_image(self, image, channel):
image_path = os.path.join(self.input_dir, image)
i = Image.open(image_path)
mask = None
c = channel[0].upper()
if c in i.getbands():
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
mask = torch.from_numpy(mask)
if c == 'A':
mask = 1. - mask
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (mask,)
@classmethod
def IS_CHANGED(s, image, channel):
image_path = os.path.join(s.input_dir, image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"]
@ -626,6 +659,7 @@ NODE_CLASS_MAPPINGS = {
"LatentUpscale": LatentUpscale,
"SaveImage": SaveImage,
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"ImageScale": ImageScale,
"ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea,