diff --git a/comfy/model_base.py b/comfy/model_base.py
index c3c807a6..ad661ec7 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -105,6 +105,29 @@ class BaseModel(torch.nn.Module):
         return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
 
+def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
+    adm_inputs = []
+    weights = []
+    noise_aug = []
+    for unclip_cond in unclip_conditioning:
+        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
+            weight = unclip_cond["strength"]
+            noise_augment = unclip_cond["noise_augmentation"]
+            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+            weights.append(weight)
+            noise_aug.append(noise_augment)
+            adm_inputs.append(adm_out)
+
+    if len(noise_aug) > 1:
+        adm_out = torch.stack(adm_inputs).sum(0)
+        noise_augment = noise_augment_merge
+        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+        adm_out = torch.cat((c_adm, noise_level_emb), 1)
+
+    return adm_out
 
 
 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
@@ -114,33 +137,11 @@ class SD21UNCLIP(BaseModel):
     def encode_adm(self, **kwargs):
         unclip_conditioning = kwargs.get("unclip_conditioning", None)
         device = kwargs["device"]
-
-        if unclip_conditioning is not None:
-            adm_inputs = []
-            weights = []
-            noise_aug = []
-            for unclip_cond in unclip_conditioning:
-                for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
-                    weight = unclip_cond["strength"]
-                    noise_augment = unclip_cond["noise_augmentation"]
-                    noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
-                    c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
-                    adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
-                    weights.append(weight)
-                    noise_aug.append(noise_augment)
-                    adm_inputs.append(adm_out)
-
-            if len(noise_aug) > 1:
-                adm_out = torch.stack(adm_inputs).sum(0)
-                #TODO: add a way to control this
-                noise_augment = 0.05
-                noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
-                c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
-                adm_out = torch.cat((c_adm, noise_level_emb), 1)
+        if unclip_conditioning is None:
+            return torch.zeros((1, self.adm_channels))
         else:
-            adm_out = torch.zeros((1, self.adm_channels))
+            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
 
-        return adm_out
 
 class SDInpaint(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
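
One side effect of hoisting the loop out of `encode_adm` into a module-level `unclip_adm` is that the multi-image merge branch (and the new `unclip_noise_augment_merge` override of the previously hard-coded 0.05) can be exercised without constructing a full `SD21UNCLIP` model. A minimal sketch, assuming `unclip_adm` is importable from `comfy.model_base`; the `_Stub*` classes below are hypothetical stand-ins that only mimic the `(c_adm, noise_level_emb)` contract of the real noise augmentor, which actually noises the embedding rather than passing it through:

```python
import torch

from comfy.model_base import unclip_adm

class _StubTimeEmbed:
    dim = 768

class _StubNoiseAugmentor:
    # Mirrors the attributes unclip_adm reads: max_noise_level, time_embed.dim,
    # and a call returning (c_adm, noise_level_emb).
    max_noise_level = 1000
    time_embed = _StubTimeEmbed()

    def __call__(self, x, noise_level=None):
        # Return the embedding (un-noised in this stub) plus a dummy
        # noise-level embedding of matching batch size.
        c_adm = x if x.dim() == 2 else x.unsqueeze(0)
        return c_adm, torch.zeros((c_adm.shape[0], self.time_embed.dim))

class _StubClipVisionOutput:
    def __init__(self):
        self.image_embeds = torch.randn(1, 768)

conditioning = [
    {"clip_vision_output": _StubClipVisionOutput(), "strength": 1.0, "noise_augmentation": 0.0},
    {"clip_vision_output": _StubClipVisionOutput(), "strength": 0.5, "noise_augmentation": 0.1},
]

# Two conditionings trigger the merge branch: the weighted sum is re-noised at
# noise_augment_merge instead of the old hard-coded 0.05.
adm = unclip_adm(conditioning, torch.device("cpu"), _StubNoiseAugmentor(), noise_augment_merge=0.05)
print(adm.shape)  # torch.Size([1, 1536]): image embed concatenated with noise-level embed
```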