From 782a24fce65272649191635ce43e3bec5e09c5e2 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 18 Oct 2023 16:48:37 -0400
Subject: [PATCH] Refactor cond_concat into model object.

---
 comfy/model_base.py | 34 +++++++++++++++++++++++++++++++++-
 comfy/samplers.py   | 28 ++------------------------
 2 files changed, 37 insertions(+), 25 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index ed2dc83e..8e704022 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
+        self.inpaint_model = False
         print("model_type", model_type.name)
         print("adm", self.adm_channels)
 
@@ -71,6 +72,37 @@ class BaseModel(torch.nn.Module):
     def encode_adm(self, **kwargs):
         return None
 
+    def cond_concat(self, **kwargs):
+        if self.inpaint_model:
+            concat_keys = ("mask", "masked_image")
+            cond_concat = []
+            denoise_mask = kwargs.get("denoise_mask", None)
+            latent_image = kwargs.get("latent_image", None)
+            noise = kwargs.get("noise", None)
+
+            def blank_inpaint_image_like(latent_image):
+                blank_image = torch.ones_like(latent_image)
+                # these are the values for "zero" in pixel space translated to latent space
+                blank_image[:,0] *= 0.8223
+                blank_image[:,1] *= -0.6876
+                blank_image[:,2] *= 0.6364
+                blank_image[:,3] *= 0.1380
+                return blank_image
+
+            for ck in concat_keys:
+                if denoise_mask is not None:
+                    if ck == "mask":
+                        cond_concat.append(denoise_mask[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
+                else:
+                    if ck == "mask":
+                        cond_concat.append(torch.ones_like(noise)[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(blank_inpaint_image_like(noise))
+            return cond_concat
+        return None
+
     def load_model_weights(self, sd, unet_prefix=""):
         to_load = {}
         keys = list(sd.keys())
@@ -112,7 +144,7 @@ class BaseModel(torch.nn.Module):
         return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
 
     def set_inpaint(self):
-        self.concat_keys = ("mask", "masked_image")
+        self.inpaint_model = True
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
     adm_inputs = []
diff --git a/comfy/samplers.py b/comfy/samplers.py
index e43f7a6f..bb8bfdfa 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -358,15 +358,6 @@ def sgm_scheduler(model, steps):
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
-def blank_inpaint_image_like(latent_image):
-    blank_image = torch.ones_like(latent_image)
-    # these are the values for "zero" in pixel space translated to latent space
-    blank_image[:,0] *= 0.8223
-    blank_image[:,1] *= -0.6876
-    blank_image[:,2] *= 0.6364
-    blank_image[:,3] *= 0.1380
-    return blank_image
-
 def get_mask_aabb(masks):
     if masks.numel() == 0:
         return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@@ -671,21 +662,10 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
 
     extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
 
-    cond_concat = None
-    if hasattr(model, 'concat_keys'): #inpaint
-        cond_concat = []
-        for ck in model.concat_keys:
-            if denoise_mask is not None:
-                if ck == "mask":
-                    cond_concat.append(denoise_mask[:,:1])
-                elif ck == "masked_image":
-                    cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
-            else:
-                if ck == "mask":
-                    cond_concat.append(torch.ones_like(noise)[:,:1])
-                elif ck == "masked_image":
-                    cond_concat.append(blank_inpaint_image_like(noise))
-    extra_args["cond_concat"] = cond_concat
+    if hasattr(model, 'cond_concat'):
+        cond_concat = model.cond_concat(noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+        if cond_concat is not None:
+            extra_args["cond_concat"] = cond_concat
 
     samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
     return model.process_latent_out(samples.to(torch.float32))
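
Note (not part of the patch): a minimal sketch of how the tensors returned by BaseModel.cond_concat() are expected to be consumed downstream. The concatenation itself happens outside this diff; the names apply_model_sketch and unet below are illustrative assumptions, not ComfyUI API.

    import torch

    def apply_model_sketch(unet, x, t, context, cond_concat=None):
        # When the model is an inpaint model, cond_concat holds the extra
        # channels (1-channel mask + 4-channel masked latent). They are
        # joined to the noisy latent along the channel dimension before
        # the UNet forward pass, e.g. 4 + 1 + 4 = 9 input channels.
        if cond_concat is not None:
            x = torch.cat([x] + cond_concat, dim=1)
        return unet(x, t, context=context)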