diff --git a/comfy/samplers.py b/comfy/samplers.py
index c8ab106b..7e2e667b 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -330,8 +330,7 @@ class KSampler:
                     if ck == "mask":
                         cond_concat.append(denoise_mask[:,:1])
                     elif ck == "masked_image":
-                        blank_image = blank_inpaint_image_like(latent_image)
-                        cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image)
+                        cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
                 else:
                     if ck == "mask":
                         cond_concat.append(torch.ones_like(noise)[:,:1])
diff --git a/nodes.py b/nodes.py
index 0a41511c..5784ba3d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -121,6 +121,38 @@ class VAEEncode:
         return ({"samples":t}, )


+class VAEEncodeForInpaint:
+    def __init__(self, device="cpu"):
+        self.device = device
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "latent/inpaint"
+
+    def encode(self, vae, pixels, mask):
+        print(pixels.shape, mask.shape)
+        x = (pixels.shape[1] // 64) * 64
+        y = (pixels.shape[2] // 64) * 64
+        if pixels.shape[1] != x or pixels.shape[2] != y:
+            pixels = pixels[:,:x,:y,:]
+            mask = mask[:x,:y]
+
+        #shave off a few pixels to keep things seamless
+        kernel_tensor = torch.ones((1, 1, 6, 6))
+        mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
+        print(mask_erosion.shape, pixels.shape)
+        for i in range(3):
+            pixels[:,:,:,i] -= 0.5
+            pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
+            pixels[:,:,:,i] += 0.5
+        t = vae.encode(pixels)
+
+        return ({"samples":t, "noise_mask": mask}, )
+
 class CheckpointLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     config_dir = os.path.join(models_dir, "configs")
@@ -391,7 +423,7 @@ class SetLatentNoiseMask:
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "set_mask"

-    CATEGORY = "latent"
+    CATEGORY = "latent/inpaint"

     def set_mask(self, samples, mask):
         s = samples.copy()
@@ -654,6 +686,7 @@ NODE_CLASS_MAPPINGS = {
     "CLIPTextEncode": CLIPTextEncode,
     "VAEDecode": VAEDecode,
     "VAEEncode": VAEEncode,
+    "VAEEncodeForInpaint": VAEEncodeForInpaint,
     "VAELoader": VAELoader,
     "EmptyLatentImage": EmptyLatentImage,
     "LatentUpscale": LatentUpscale,