diff --git a/comfy/model_base.py b/comfy/model_base.py index f5952620..52c87ede 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -100,11 +100,29 @@ class BaseModel(torch.nn.Module): if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] - denoise_mask = kwargs.get("denoise_mask", None) - latent_image = kwargs.get("latent_image", None) + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + concat_latent_image = kwargs.get("concat_latent_image", None) + if concat_latent_image is None: + concat_latent_image = kwargs.get("latent_image", None) + else: + concat_latent_image = self.process_latent_in(concat_latent_image) + noise = kwargs.get("noise", None) device = kwargs["device"] + if concat_latent_image.shape[1:] != noise.shape[1:]: + concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) + + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] + + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space @@ -117,9 +135,9 @@ class BaseModel(torch.nn.Module): for ck in concat_keys: if denoise_mask is not None: if ck == "mask": - cond_concat.append(denoise_mask[:,:1].to(device)) + cond_concat.append(denoise_mask.to(device)) elif ck == "masked_image": - cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space else: if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) diff --git a/nodes.py b/nodes.py index 96d9e51a..6c7317b6 100644 --- a/nodes.py +++ b/nodes.py @@ -359,6 +359,62 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class InpaintModelConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + "mask": ("MASK", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, positive, negative, pixels, vae, mask): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + + orig_pixels = pixels + pixels = orig_pixels.clone() + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + + m = (1.0 - mask.round()).squeeze(1) + for i in range(3): + pixels[:,:,:,i] -= 0.5 + pixels[:,:,:,i] *= m + pixels[:,:,:,i] += 0.5 + concat_latent = vae.encode(pixels) + orig_latent = vae.encode(orig_pixels) + + out_latent = {} + + out_latent["samples"] = orig_latent + out_latent["noise_mask"] = mask + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + d["concat_mask"] = mask + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + + class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -1628,10 +1684,11 @@ class ImagePadForOutpaint: def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() - new_image = torch.zeros( + new_image = torch.ones( (d1, d2 + top + bottom, d3 + left + right, d4), dtype=torch.float32, - ) + ) * 0.5 + new_image[:, top:top + d2, left:left + d3, :] = image mask = torch.ones( @@ -1723,6 +1780,7 @@ NODE_CLASS_MAPPINGS = { "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "GLIGENLoader": GLIGENLoader, "GLIGENTextBoxApply": GLIGENTextBoxApply, + "InpaintModelConditioning": InpaintModelConditioning, "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader,