diff --git a/comfy/model_base.py b/comfy/model_base.py index 7370c19f..9adea9a5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -60,6 +60,37 @@ class SD21UNCLIP(BaseModel): super().__init__(unet_config, v_prediction) self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config) + def encode_adm(self, **kwargs): + unclip_conditioning = kwargs.get("unclip_conditioning", None) + device = kwargs["device"] + + if unclip_conditioning is not None: + adm_inputs = [] + weights = [] + noise_aug = [] + for unclip_cond in unclip_conditioning: + adm_cond = unclip_cond["clip_vision_output"].image_embeds + weight = unclip_cond["strength"] + noise_augment = unclip_cond["noise_augmentation"] + noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight + weights.append(weight) + noise_aug.append(noise_augment) + adm_inputs.append(adm_out) + + if len(noise_aug) > 1: + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: add a way to control this + noise_augment = 0.05 + noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) + else: + adm_out = torch.zeros((1, self.adm_channels)) + + return adm_out + class SDInpaint(BaseModel): def __init__(self, unet_config, v_prediction=False): super().__init__(unet_config, v_prediction) diff --git a/comfy/samplers.py b/comfy/samplers.py index a33d150d..d3cd901e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -460,42 +460,18 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): uncond[temp[1]] = [o[0], n] -def encode_adm(conds, batch_size, device, noise_augmentor=None): +def encode_adm(model, conds, batch_size, device): for t in range(len(conds)): x = conds[t] adm_out = None - if noise_augmentor is not None: - if 'adm' in x[1]: - adm_inputs = [] - weights = [] - noise_aug = [] - adm_in = x[1]["adm"] - for adm_c in adm_in: - adm_cond = adm_c[0].image_embeds - weight = adm_c[1] - noise_augment = adm_c[2] - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight - weights.append(weight) - noise_aug.append(noise_augment) - adm_inputs.append(adm_out) - - if len(noise_aug) > 1: - adm_out = torch.stack(adm_inputs).sum(0) - #TODO: add a way to control this - noise_augment = 0.05 - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) - else: - adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) + if 'adm' in x[1]: + adm_out = x[1]["adm"] else: - if 'adm' in x[1]: - adm_out = x[1]["adm"].to(device) + params = x[1].copy() + adm_out = model.encode_adm(device=device, **params) if adm_out is not None: x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) + x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device) return conds @@ -603,11 +579,8 @@ class KSampler: precision_scope = contextlib.nullcontext if self.model.is_adm(): - noise_augmentor = None - if hasattr(self.model, 'noise_augmentor'): #unclip - noise_augmentor = self.model.noise_augmentor - positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor) - negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor) + positive = encode_adm(self.model, positive, noise.shape[0], self.device) + negative = encode_adm(self.model, negative, noise.shape[0], self.device) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} diff --git a/nodes.py b/nodes.py index b057504e..12243fab 100644 --- a/nodes.py +++ b/nodes.py @@ -623,11 +623,11 @@ class unCLIPConditioning: c = [] for t in conditioning: o = t[1].copy() - x = (clip_vision_output, strength, noise_augmentation) - if "adm" in o: - o["adm"] = o["adm"][:] + [x] + x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation} + if "unclip_conditioning" in o: + o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x] else: - o["adm"] = [x] + o["unclip_conditioning"] = [x] n = [t[0], o] c.append(n) return (c, )