mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Refactor cond_concat into model object.
This commit is contained in:
parent
0d45a565da
commit
782a24fce6
@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||||
if self.adm_channels is None:
|
if self.adm_channels is None:
|
||||||
self.adm_channels = 0
|
self.adm_channels = 0
|
||||||
|
self.inpaint_model = False
|
||||||
print("model_type", model_type.name)
|
print("model_type", model_type.name)
|
||||||
print("adm", self.adm_channels)
|
print("adm", self.adm_channels)
|
||||||
|
|
||||||
@ -71,6 +72,37 @@ class BaseModel(torch.nn.Module):
|
|||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def cond_concat(self, **kwargs):
|
||||||
|
if self.inpaint_model:
|
||||||
|
concat_keys = ("mask", "masked_image")
|
||||||
|
cond_concat = []
|
||||||
|
denoise_mask = kwargs.get("denoise_mask", None)
|
||||||
|
latent_image = kwargs.get("latent_image", None)
|
||||||
|
noise = kwargs.get("noise", None)
|
||||||
|
|
||||||
|
def blank_inpaint_image_like(latent_image):
|
||||||
|
blank_image = torch.ones_like(latent_image)
|
||||||
|
# these are the values for "zero" in pixel space translated to latent space
|
||||||
|
blank_image[:,0] *= 0.8223
|
||||||
|
blank_image[:,1] *= -0.6876
|
||||||
|
blank_image[:,2] *= 0.6364
|
||||||
|
blank_image[:,3] *= 0.1380
|
||||||
|
return blank_image
|
||||||
|
|
||||||
|
for ck in concat_keys:
|
||||||
|
if denoise_mask is not None:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(denoise_mask[:,:1])
|
||||||
|
elif ck == "masked_image":
|
||||||
|
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||||
|
else:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||||
|
elif ck == "masked_image":
|
||||||
|
cond_concat.append(blank_inpaint_image_like(noise))
|
||||||
|
return cond_concat
|
||||||
|
return None
|
||||||
|
|
||||||
def load_model_weights(self, sd, unet_prefix=""):
|
def load_model_weights(self, sd, unet_prefix=""):
|
||||||
to_load = {}
|
to_load = {}
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
@ -112,7 +144,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||||
|
|
||||||
def set_inpaint(self):
|
def set_inpaint(self):
|
||||||
self.concat_keys = ("mask", "masked_image")
|
self.inpaint_model = True
|
||||||
|
|
||||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||||
adm_inputs = []
|
adm_inputs = []
|
||||||
|
@ -358,15 +358,6 @@ def sgm_scheduler(model, steps):
|
|||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def blank_inpaint_image_like(latent_image):
|
|
||||||
blank_image = torch.ones_like(latent_image)
|
|
||||||
# these are the values for "zero" in pixel space translated to latent space
|
|
||||||
blank_image[:,0] *= 0.8223
|
|
||||||
blank_image[:,1] *= -0.6876
|
|
||||||
blank_image[:,2] *= 0.6364
|
|
||||||
blank_image[:,3] *= 0.1380
|
|
||||||
return blank_image
|
|
||||||
|
|
||||||
def get_mask_aabb(masks):
|
def get_mask_aabb(masks):
|
||||||
if masks.numel() == 0:
|
if masks.numel() == 0:
|
||||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||||
@ -671,20 +662,9 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||||
|
|
||||||
cond_concat = None
|
if hasattr(model, 'cond_concat'):
|
||||||
if hasattr(model, 'concat_keys'): #inpaint
|
cond_concat = model.cond_concat(noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
cond_concat = []
|
if cond_concat is not None:
|
||||||
for ck in model.concat_keys:
|
|
||||||
if denoise_mask is not None:
|
|
||||||
if ck == "mask":
|
|
||||||
cond_concat.append(denoise_mask[:,:1])
|
|
||||||
elif ck == "masked_image":
|
|
||||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
|
||||||
else:
|
|
||||||
if ck == "mask":
|
|
||||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
|
||||||
elif ck == "masked_image":
|
|
||||||
cond_concat.append(blank_inpaint_image_like(noise))
|
|
||||||
extra_args["cond_concat"] = cond_concat
|
extra_args["cond_concat"] = cond_concat
|
||||||
|
|
||||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||||
|
Loading…
Reference in New Issue
Block a user