Refactor cond_concat into model object.

This commit is contained in:
comfyanonymous 2023-10-18 16:48:37 -04:00
parent 0d45a565da
commit 782a24fce6
2 changed files with 37 additions and 25 deletions

View File

@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module):
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
self.inpaint_model = False
print("model_type", model_type.name)
print("adm", self.adm_channels)
@ -71,6 +72,37 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs):
    """Default ADM-encoding hook: ignores every keyword argument and returns None."""
    return None
def cond_concat(self, **kwargs):
    """Build the extra conditioning tensors used by inpaint models.

    Keyword Args:
        noise: tensor used as the ``ones_like`` template for the mask and the
            blank image when no denoise mask is supplied.
        latent_image: returned unchanged as the "masked_image" entry when a
            denoise mask is supplied.
        denoise_mask: when not None, its first channel (``[:, :1]``) becomes
            the "mask" entry.

    Returns:
        ``None`` when ``self.inpaint_model`` is false; otherwise a two-item
        list ordered ``[mask, masked_image]``.
    """
    if not self.inpaint_model:
        return None

    denoise_mask = kwargs.get("denoise_mask", None)
    if denoise_mask is not None:
        # NOTE: the latent_image should be masked by the mask in pixel space
        return [denoise_mask[:, :1], kwargs.get("latent_image", None)]

    noise = kwargs.get("noise", None)
    # Slice first so only a single-channel tensor is allocated for the mask,
    # instead of a full-size tensor of which one channel is kept.
    full_mask = torch.ones_like(noise[:, :1])

    blank_image = torch.ones_like(noise)
    # these are the values for "zero" in pixel space translated to latent space
    for channel, zero_value in enumerate((0.8223, -0.6876, 0.6364, 0.1380)):
        blank_image[:, channel] *= zero_value
    return [full_mask, blank_image]
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
@ -112,7 +144,7 @@ class BaseModel(torch.nn.Module):
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
def set_inpaint(self):
    """Mark this model as an inpaint model."""
    # Attribute whose presence the sampler-side `hasattr(model, 'concat_keys')` check looks for.
    self.concat_keys = ("mask", "masked_image")
    # Flag read by cond_concat() to decide whether to build mask / masked-image tensors.
    self.inpaint_model = True
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []

View File

@ -358,15 +358,6 @@ def sgm_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
    """Return a tensor like ``latent_image`` whose first four channels hold the latent "zero" values.

    The result starts as all ones (via ``torch.ones_like``); channels 0-3
    along dim 1 are then scaled by fixed constants.
    """
    # these are the values for "zero" in pixel space translated to latent space
    channel_values = (0.8223, -0.6876, 0.6364, 0.1380)
    filled = torch.ones_like(latent_image)
    for index, value in enumerate(channel_values):
        filled[:, index] *= value
    return filled
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -671,21 +662,10 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
cond_concat = None
if hasattr(model, 'concat_keys'): #inpaint
cond_concat = []
for ck in model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
if hasattr(model, 'cond_concat'):
cond_concat = model.cond_concat(noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
if cond_concat is not None:
extra_args["cond_concat"] = cond_concat
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))