Use inpaint models the proper way by using VAEEncodeForInpaint.

This commit is contained in:
comfyanonymous 2023-02-15 20:44:51 -05:00
parent 59bef84bc8
commit bc69fb5245
2 changed files with 35 additions and 3 deletions

View File

@ -330,8 +330,7 @@ class KSampler:
if ck == "mask": if ck == "mask":
cond_concat.append(denoise_mask[:,:1]) cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image": elif ck == "masked_image":
blank_image = blank_inpaint_image_like(latent_image) cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image)
else: else:
if ck == "mask": if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1]) cond_concat.append(torch.ones_like(noise)[:,:1])

View File

@ -121,6 +121,38 @@ class VAEEncode:
return ({"samples":t}, ) return ({"samples":t}, )
class VAEEncodeForInpaint:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask):
print(pixels.shape, mask.shape)
x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
mask = mask[:x,:y]
#shave off a few pixels to keep things seamless
kernel_tensor = torch.ones((1, 1, 6, 6))
mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
print(mask_erosion.shape, pixels.shape)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
pixels[:,:,:,i] += 0.5
t = vae.encode(pixels)
return ({"samples":t, "noise_mask": mask}, )
class CheckpointLoader: class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
config_dir = os.path.join(models_dir, "configs") config_dir = os.path.join(models_dir, "configs")
@ -391,7 +423,7 @@ class SetLatentNoiseMask:
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "set_mask" FUNCTION = "set_mask"
CATEGORY = "latent" CATEGORY = "latent/inpaint"
def set_mask(self, samples, mask): def set_mask(self, samples, mask):
s = samples.copy() s = samples.copy()
@ -654,6 +686,7 @@ NODE_CLASS_MAPPINGS = {
"CLIPTextEncode": CLIPTextEncode, "CLIPTextEncode": CLIPTextEncode,
"VAEDecode": VAEDecode, "VAEDecode": VAEDecode,
"VAEEncode": VAEEncode, "VAEEncode": VAEEncode,
"VAEEncodeForInpaint": VAEEncodeForInpaint,
"VAELoader": VAELoader, "VAELoader": VAELoader,
"EmptyLatentImage": EmptyLatentImage, "EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale, "LatentUpscale": LatentUpscale,