mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load the unCLIP checkpoints. unCLIPConditioning is used to add the image conditioning; it takes as input a CLIPVisionEncode output, and the CLIPVisionEncode node has been moved to the conditioning section.
This commit is contained in:
parent
0d972b85e6
commit
809bcc8ceb
62
comfy/clip_vision.py
Normal file
62
comfy/clip_vision.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
|
||||||
|
from .utils import load_torch_file, transformers_convert
|
||||||
|
import os
|
||||||
|
|
||||||
|
class ClipVisionModel():
    """CLIP vision model with projection head, bundled with its image preprocessor.

    The network is built from a JSON config only; weights are supplied
    afterwards through :meth:`load_sd`.
    """

    def __init__(self, json_config):
        """Build the (uninitialised) model described by ``json_config``.

        Args:
            json_config: Path to a JSON file readable by
                ``CLIPVisionConfig.from_json_file``.
        """
        vision_config = CLIPVisionConfig.from_json_file(json_config)
        self.model = CLIPVisionModelWithProjection(vision_config)

        # Preprocessing: resize + center crop to 224, RGB conversion, then
        # per-channel normalisation with the mean/std listed below.
        preprocessing = dict(
            size=224,
            crop_size=224,
            resample=3,  # bicubic
            do_resize=True,
            do_center_crop=True,
            do_convert_rgb=True,
            do_normalize=True,
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711],
        )
        self.processor = CLIPImageProcessor(**preprocessing)

    def load_sd(self, sd):
        """Load weights from state dict ``sd``.

        Loading is non-strict, so missing or unexpected keys do not raise.
        """
        self.model.load_state_dict(sd, strict=False)

    def encode_image(self, image):
        """Run the vision model on the first element of ``image``.

        Only ``image[0]`` is encoded; any further elements are ignored.

        Returns:
            Whatever the wrapped model returns for the processed input.
        """
        first_image = image[0]
        model_inputs = self.processor(images=[first_image], return_tensors="pt")
        return self.model(**model_inputs)
|
||||||
|
|
||||||
|
def convert_to_transformers(sd):
    """Rename ``embedder.model.visual.*`` keys of ``sd`` to ``vision_model.*`` names.

    If the marker key for the ``embedder.model.visual`` layout is absent the
    state dict is returned untouched. Otherwise the dict is modified in place
    (keys are popped and re-inserted under their new names) and then handed to
    ``transformers_convert`` for the remaining keys.

    Args:
        sd: State dict (mapping of parameter name to tensor).

    Returns:
        The converted state dict.
    """
    marker = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
    if marker not in sd:
        return sd

    # (old name, new name) pairs for the keys handled explicitly here.
    # NOTE(review): "pre_layrnorm" (sic) is presumably the attribute spelling
    # used by the target model class -- do not "fix" it without confirming.
    renames = (
        ("embedder.model.visual.class_embedding", "vision_model.embeddings.class_embedding"),
        ("embedder.model.visual.conv1.weight", "vision_model.embeddings.patch_embedding.weight"),
        ("embedder.model.visual.positional_embedding", "vision_model.embeddings.position_embedding.weight"),
        ("embedder.model.visual.ln_post.bias", "vision_model.post_layernorm.bias"),
        ("embedder.model.visual.ln_post.weight", "vision_model.post_layernorm.weight"),
        ("embedder.model.visual.ln_pre.bias", "vision_model.pre_layrnorm.bias"),
        ("embedder.model.visual.ln_pre.weight", "vision_model.pre_layrnorm.weight"),
    )
    for old_name, new_name in renames:
        if old_name in sd:
            sd[new_name] = sd.pop(old_name)

    # The projection is stored under a ".weight" key with its two axes swapped.
    projection_key = "embedder.model.visual.proj"
    if projection_key in sd:
        sd['visual_projection.weight'] = sd.pop(projection_key).transpose(0, 1)

    return transformers_convert(sd, "embedder.model.visual", "vision_model", 32)
|
||||||
|
|
||||||
|
def load_clipvision_from_sd(sd):
    """Build a ``ClipVisionModel`` from state dict ``sd`` and load the weights.

    The state dict is first passed through ``convert_to_transformers``. The
    config file is then chosen by probing for an encoder layer with index 30:
    if present the ``_h`` config is used, otherwise the ``_vitl`` config.
    Both config files are looked up next to this module.

    Returns:
        The constructed ``ClipVisionModel`` with ``sd`` loaded.
    """
    sd = convert_to_transformers(sd)

    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        config_name = "clip_vision_config_h.json"
    else:
        config_name = "clip_vision_config_vitl.json"

    module_dir = os.path.dirname(os.path.realpath(__file__))
    json_config = os.path.join(module_dir, config_name)

    clip = ClipVisionModel(json_config)
    clip.load_sd(sd)
    return clip
|
||||||
|
|
||||||
|
def load(ckpt_path):
    """Read the checkpoint at ``ckpt_path`` and return a CLIP vision model built from it."""
    state_dict = load_torch_file(ckpt_path)
    return load_clipvision_from_sd(state_dict)
|
18
comfy/clip_vision_config_h.json
Normal file
18
comfy/clip_vision_config_h.json
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
{
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dropout": 0.0,
|
||||||
|
"hidden_act": "gelu",
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"image_size": 224,
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 5120,
|
||||||
|
"layer_norm_eps": 1e-05,
|
||||||
|
"model_type": "clip_vision_model",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_channels": 3,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
"patch_size": 14,
|
||||||
|
"projection_dim": 1024,
|
||||||
|
"torch_dtype": "float32"
|
||||||
|
}
|
@ -1,8 +1,4 @@
|
|||||||
{
|
{
|
||||||
"_name_or_path": "openai/clip-vit-large-patch14",
|
|
||||||
"architectures": [
|
|
||||||
"CLIPVisionModel"
|
|
||||||
],
|
|
||||||
"attention_dropout": 0.0,
|
"attention_dropout": 0.0,
|
||||||
"dropout": 0.0,
|
"dropout": 0.0,
|
||||||
"hidden_act": "quick_gelu",
|
"hidden_act": "quick_gelu",
|
||||||
@ -18,6 +14,5 @@
|
|||||||
"num_hidden_layers": 24,
|
"num_hidden_layers": 24,
|
||||||
"patch_size": 14,
|
"patch_size": 14,
|
||||||
"projection_dim": 768,
|
"projection_dim": 768,
|
||||||
"torch_dtype": "float32",
|
"torch_dtype": "float32"
|
||||||
"transformers_version": "4.24.0"
|
|
||||||
}
|
}
|
@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
|||||||
log = super().log_images(*args, **kwargs)
|
log = super().log_images(*args, **kwargs)
|
||||||
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
||||||
return log
|
return log
|
||||||
|
|
||||||
|
|
||||||
|
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
    """Latent diffusion model additionally conditioned on an image embedding.

    The image embedding is passed to the model under the ``"c_adm"`` key,
    alongside the usual cross-attention conditioning under ``"c_crossattn"``.
    """

    def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
                 freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
        """
        Args:
            embedder_config: Config for the image embedder (currently unused
                here, see the note below).
            embedding_key: Batch key under which the image to embed is found.
            embedding_dropout: Probability of zeroing a sample's image
                embedding while in training mode.
            freeze_embedder: Whether the embedder would be frozen (currently
                unused here, see the note below).
            noise_aug_config: Optional config for the noise augmentor applied
                to the image embedding; ``None`` disables augmentation.
        """
        super().__init__(*args, **kwargs)
        self.embed_key = embedding_key
        self.embedding_dropout = embedding_dropout
        # The embedder is deliberately not constructed here, so
        # `self.embedder` is unset unless `_init_embedder` is called separately.
        # NOTE(review): get_input()/log_images() need `self.embedder`;
        # presumably the embedding is computed outside this class -- confirm.
        # self._init_embedder(embedder_config, freeze_embedder)
        self._init_noise_aug(noise_aug_config)

    def _init_embedder(self, config, freeze=True):
        """Instantiate the image embedder from ``config`` and attach it.

        With ``freeze=True`` the embedder is put in eval mode, its ``train``
        method is replaced by ``disabled_train`` and its parameters stop
        requiring gradients.
        """
        embedder = instantiate_from_config(config)
        if freeze:
            embedder = embedder.eval()
            embedder.train = disabled_train
            for param in embedder.parameters():
                param.requires_grad = False
        # Always attach the embedder; previously it was only assigned in the
        # frozen case, silently discarding an unfrozen embedder.
        self.embedder = embedder

    def _init_noise_aug(self, config):
        """Instantiate the (frozen) noise augmentor, or set it to ``None``."""
        if config is not None:
            # use the KARLO schedule for noise augmentation on CLIP image embeddings
            noise_augmentor = instantiate_from_config(config)
            assert isinstance(noise_augmentor, nn.Module)
            noise_augmentor = noise_augmentor.eval()
            noise_augmentor.train = disabled_train
            self.noise_augmentor = noise_augmentor
        else:
            self.noise_augmentor = None

    def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
        """Extend the parent's input with the image-embedding conditioning.

        Returns:
            A list ``[z, all_conds, *rest]`` where ``all_conds`` is a dict with
            ``"c_crossattn"`` (the parent's conditioning wrapped in a list) and
            ``"c_adm"`` (the image embedding), and ``rest`` are any further
            outputs of the parent's ``get_input``.
        """
        outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
        z, c = outputs[0], outputs[1]
        img = batch[self.embed_key][:bs]
        img = rearrange(img, 'b h w c -> b c h w')
        c_adm = self.embedder(img)
        if self.noise_augmentor is not None:
            c_adm, noise_level_emb = self.noise_augmentor(c_adm)
            # assume this gives embeddings of noise levels
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        if self.training:
            # Per-sample dropout: each row of c_adm is kept with probability
            # (1 - embedding_dropout) and zeroed otherwise.
            keep_mask = torch.bernoulli(
                (1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], device=c_adm.device)[:, None]
            )
            c_adm = keep_mask * c_adm
        all_conds = {"c_crossattn": [c], "c_adm": c_adm}
        noutputs = [z, all_conds]
        noutputs.extend(outputs[2:])
        return noutputs

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, **kwargs):
        """Collect inputs, reconstructions, conditioning and CFG samples for logging.

        Recognised ``kwargs``: ``unconditional_guidance_label``,
        ``unconditional_guidance_scale``, ``use_ema_scope``, ``ddim_steps``
        and ``ddim_eta``.

        Returns:
            Dict mapping log names to image tensors.
        """
        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
                                           return_original_cond=True)
        log["inputs"] = x
        log["reconstruction"] = xrec
        assert self.model.conditioning_key is not None
        assert self.cond_stage_key in ["caption", "txt"]
        xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
        log["conditioning"] = xc
        uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
        unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)

        # The unconditional branch keeps the same image embedding; only the
        # cross-attention conditioning is swapped for the unconditional one.
        uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
        ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
        with ema_scope("Sampling"):
            samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
                                             ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
                                             unconditional_guidance_scale=unconditional_guidance_scale,
                                             unconditional_conditioning=uc_, )
        x_samples_cfg = self.decode_first_stage(samples_cfg)
        log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
        return log
|
||||||
|
@ -307,7 +307,16 @@ def model_wrapper(
|
|||||||
else:
|
else:
|
||||||
x_in = torch.cat([x] * 2)
|
x_in = torch.cat([x] * 2)
|
||||||
t_in = torch.cat([t_continuous] * 2)
|
t_in = torch.cat([t_continuous] * 2)
|
||||||
c_in = torch.cat([unconditional_condition, condition])
|
if isinstance(condition, dict):
|
||||||
|
assert isinstance(unconditional_condition, dict)
|
||||||
|
c_in = dict()
|
||||||
|
for k in condition:
|
||||||
|
if isinstance(condition[k], list):
|
||||||
|
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
|
||||||
|
else:
|
||||||
|
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
|
||||||
|
else:
|
||||||
|
c_in = torch.cat([unconditional_condition, condition])
|
||||||
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
|
||||||
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
return noise_uncond + guidance_scale * (noise - noise_uncond)
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ import torch
|
|||||||
|
|
||||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||||
|
|
||||||
|
|
||||||
MODEL_TYPES = {
|
MODEL_TYPES = {
|
||||||
"eps": "noise",
|
"eps": "noise",
|
||||||
"v": "v"
|
"v": "v"
|
||||||
@ -51,12 +50,20 @@ class DPMSolverSampler(object):
|
|||||||
):
|
):
|
||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
if cbs != batch_size:
|
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
if isinstance(ctmp, torch.Tensor):
|
||||||
|
cbs = ctmp.shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
elif isinstance(conditioning, list):
|
||||||
|
for ctmp in conditioning:
|
||||||
|
if ctmp.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
else:
|
else:
|
||||||
if conditioning.shape[0] != batch_size:
|
if isinstance(conditioning, torch.Tensor):
|
||||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
# sampling
|
# sampling
|
||||||
C, H, W = shape
|
C, H, W = shape
|
||||||
@ -83,6 +90,7 @@ class DPMSolverSampler(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
|
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
|
||||||
|
lower_order_final=True)
|
||||||
|
|
||||||
return x.to(device), None
|
return x.to(device), None
|
@ -409,6 +409,15 @@ class QKVAttention(nn.Module):
|
|||||||
return count_flops_attn(model, _x, y)
|
return count_flops_attn(model, _x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Timestep(nn.Module):
    """Module wrapper that applies ``timestep_embedding`` with a fixed ``dim``."""

    def __init__(self, dim):
        super().__init__()
        # Second argument forwarded to timestep_embedding on every call.
        self.dim = dim

    def forward(self, t):
        """Return ``timestep_embedding(t, self.dim)``."""
        embedding = timestep_embedding(t, self.dim)
        return embedding
|
||||||
|
|
||||||
|
|
||||||
class UNetModel(nn.Module):
|
class UNetModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
The full UNet model with attention and timestep embedding.
|
The full UNet model with attention and timestep embedding.
|
||||||
@ -470,6 +479,7 @@ class UNetModel(nn.Module):
|
|||||||
num_attention_blocks=None,
|
num_attention_blocks=None,
|
||||||
disable_middle_self_attn=False,
|
disable_middle_self_attn=False,
|
||||||
use_linear_in_transformer=False,
|
use_linear_in_transformer=False,
|
||||||
|
adm_in_channels=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_spatial_transformer:
|
if use_spatial_transformer:
|
||||||
@ -538,6 +548,15 @@ class UNetModel(nn.Module):
|
|||||||
elif self.num_classes == "continuous":
|
elif self.num_classes == "continuous":
|
||||||
print("setting up linear c_adm embedding layer")
|
print("setting up linear c_adm embedding layer")
|
||||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||||
|
elif self.num_classes == "sequential":
|
||||||
|
assert adm_in_channels is not None
|
||||||
|
self.label_emb = nn.Sequential(
|
||||||
|
nn.Sequential(
|
||||||
|
linear(adm_in_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, time_embed_dim),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
||||||
|
@ -34,6 +34,13 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
|
|||||||
betas = 1 - alphas[1:] / alphas[:-1]
|
betas = 1 - alphas[1:] / alphas[:-1]
|
||||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||||
|
|
||||||
|
elif schedule == "squaredcos_cap_v2": # used for karlo prior
|
||||||
|
# return early
|
||||||
|
return betas_for_alpha_bar(
|
||||||
|
n_timestep,
|
||||||
|
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||||
|
)
|
||||||
|
|
||||||
elif schedule == "sqrt_linear":
|
elif schedule == "sqrt_linear":
|
||||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||||
elif schedule == "sqrt":
|
elif schedule == "sqrt":
|
||||||
@ -218,6 +225,7 @@ class GroupNorm32(nn.GroupNorm):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return super().forward(x.float()).type(x.dtype)
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
def conv_nd(dims, *args, **kwargs):
|
def conv_nd(dims, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a 1D, 2D, or 3D convolution module.
|
Create a 1D, 2D, or 3D convolution module.
|
||||||
|
59
comfy/ldm/modules/encoders/kornia_functions.py
Normal file
59
comfy/ldm/modules/encoders/kornia_functions.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py
|
||||||
|
|
||||||
|
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`. May be non-contiguous
            (e.g. the result of a ``permute``).
        mean: Mean for each channel. A 0-dim tensor or a tensor whose first
            dimension is 1 is expanded to all channels.
        std: Standard deviations for each channel. Expanded like ``mean``.

    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Raises:
        ValueError: if ``mean`` or ``std`` matches neither the channel count
            nor the leading ``(B, C)`` dimensions of ``data``.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = enhance_normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = enhance_normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """
    shape = data.shape

    # Scalars / single values apply to every channel.
    if len(mean.shape) == 0 or mean.shape[0] == 1:
        mean = mean.expand(shape[1])
    if len(std.shape) == 0 or std.shape[0] == 1:
        std = std.expand(shape[1])

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    # Add a trailing axis so the statistics broadcast over the flattened
    # spatial dimension below.
    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    # reshape (not view): view raises on non-contiguous input, reshape copies
    # only when it has to and is otherwise identical.
    out: torch.Tensor = (data.reshape(shape[0], shape[1], -1) - mean) / std

    return out.reshape(shape)
|
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from . import kornia_functions
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||||
@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
|
|||||||
c = batch[key][:, None]
|
c = batch[key][:, None]
|
||||||
if self.ucg_rate > 0. and not disable_dropout:
|
if self.ucg_rate > 0. and not disable_dropout:
|
||||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||||
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
|
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
|
||||||
c = c.long()
|
c = c.long()
|
||||||
c = self.embedding(c)
|
c = self.embedding(c)
|
||||||
return c
|
return c
|
||||||
@ -57,18 +58,20 @@ def disabled_train(self, mode=True):
|
|||||||
|
|
||||||
class FrozenT5Embedder(AbstractEncoder):
|
class FrozenT5Embedder(AbstractEncoder):
|
||||||
"""Uses the T5 transformer encoder for text"""
|
"""Uses the T5 transformer encoder for text"""
|
||||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
|
||||||
|
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
|
||||||
|
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||||
self.device = device
|
self.device = device
|
||||||
self.max_length = max_length # TODO: typical value?
|
self.max_length = max_length # TODO: typical value?
|
||||||
if freeze:
|
if freeze:
|
||||||
self.freeze()
|
self.freeze()
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
self.transformer = self.transformer.eval()
|
self.transformer = self.transformer.eval()
|
||||||
#self.train = disabled_train
|
# self.train = disabled_train
|
||||||
for param in self.parameters():
|
for param in self.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
"pooled",
|
"pooled",
|
||||||
"hidden"
|
"hidden"
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
self.transformer = self.transformer.eval()
|
self.transformer = self.transformer.eval()
|
||||||
#self.train = disabled_train
|
# self.train = disabled_train
|
||||||
for param in self.parameters():
|
for param in self.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||||
tokens = batch_encoding["input_ids"].to(self.device)
|
tokens = batch_encoding["input_ids"].to(self.device)
|
||||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
|
||||||
if self.layer == "last":
|
if self.layer == "last":
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
elif self.layer == "pooled":
|
elif self.layer == "pooled":
|
||||||
@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|||||||
return self(text)
|
return self(text)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipImageEmbedder(nn.Module):
    """Image embedder built on a model loaded through the ``clip`` package."""

    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=True,
            ucg_rate=0.
    ):
        """
        Args:
            model: Model name passed to ``clip.load``.
            jit: Forwarded to ``clip.load``.
            device: Device passed to ``clip.load``.
            antialias: Whether the resize in :meth:`preprocess` uses
                anti-aliasing.
            ucg_rate: Probability of zeroing a sample's embedding in
                :meth:`forward` (0 disables the dropout).
        """
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)

        self.antialias = antialias

        # Per-channel statistics used to re-normalize images for CLIP;
        # non-persistent so they are not written to state dicts.
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        """Resize ``x`` to 224x224 and re-normalize it with ``mean``/``std``."""
        # Bicubic resize, standing in for the commented-out kornia resize this
        # replaced; that call passed antialias=self.antialias, so the
        # constructor flag is honoured here instead of being hard-coded.
        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True,
                                            antialias=self.antialias)
        # normalize to [0,1]
        x = (x + 1.) / 2.
        # re-normalize according to clip
        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
        return x

    def forward(self, x, no_dropout=False):
        """Embed ``x``; optionally zero whole embeddings at random.

        Args:
            x: Image batch.
            no_dropout: If true, skip the ``ucg_rate`` dropout.

        Returns:
            The image embedding cast to ``x.dtype``.
        """
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            # One Bernoulli draw per sample: keep the row with probability
            # (1 - ucg_rate), otherwise zero it.
            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
        return out
|
||||||
|
|
||||||
|
|
||||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||||
"""
|
"""
|
||||||
Uses the OpenCLIP transformer encoder for text
|
Uses the OpenCLIP transformer encoder for text
|
||||||
"""
|
"""
|
||||||
LAYERS = [
|
LAYERS = [
|
||||||
#"pooled",
|
# "pooled",
|
||||||
"last",
|
"last",
|
||||||
"penultimate"
|
"penultimate"
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||||
freeze=True, layer="last"):
|
freeze=True, layer="last"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
|||||||
x = self.model.ln_final(x)
|
x = self.model.ln_final(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||||
for i, r in enumerate(self.model.transformer.resblocks):
|
for i, r in enumerate(self.model.transformer.resblocks):
|
||||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||||
break
|
break
|
||||||
@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
|||||||
return self(text)
|
return self(text)
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
        """
        Args:
            arch: Architecture name passed to ``open_clip``.
            version: Pretrained tag passed to ``open_clip``.
            device: Stored on the instance; the model itself is created on CPU.
            max_length: Stored on the instance.
            freeze: If true, put the model in eval mode and disable gradients.
            layer: Output layer selector; ``"penultimate"`` is not implemented.
            antialias: Whether the resize in :meth:`preprocess` uses
                anti-aliasing.
            ucg_rate: Probability of zeroing a sample's embedding in
                :meth:`forward` (0 disables the dropout).

        Raises:
            NotImplementedError: if ``layer == "penultimate"``.
        """
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version, )
        # Only the visual branch is used; drop the `transformer` attribute.
        del model.transformer
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()

        self.antialias = antialias

        # Per-channel statistics used to re-normalize images for CLIP;
        # non-persistent so they are not written to state dicts.
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        """Resize ``x`` to 224x224 and re-normalize it with ``mean``/``std``."""
        # Bicubic resize, standing in for the commented-out kornia resize this
        # replaced; that call passed antialias=self.antialias, so the
        # constructor flag is honoured here instead of being hard-coded.
        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True,
                                            antialias=self.antialias)
        # normalize to [0,1]
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Switch the model to eval mode and stop gradients for all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        """Embed ``image``; optionally zero whole embeddings at random."""
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            # One Bernoulli draw per sample: keep the row with probability
            # (1 - ucg_rate), otherwise zero it.
            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        """Preprocess ``img`` and run it through ``self.model.visual``."""
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x

    def encode(self, text):
        """Alias for calling the module; ``text`` is forwarded as the image argument."""
        return self(text)
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
||||||
clip_max_length=77, t5_max_length=77):
|
clip_max_length=77, t5_max_length=77):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
||||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
|
||||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
|
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
return self(text)
|
return self(text)
|
||||||
@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
|
|||||||
clip_z = self.clip_encoder.encode(text)
|
clip_z = self.clip_encoder.encode(text)
|
||||||
t5_z = self.t5_encoder.encode(text)
|
t5_z = self.t5_encoder.encode(text)
|
||||||
return [clip_z, t5_z]
|
return [clip_z, t5_z]
|
||||||
|
|
||||||
|
|
||||||
|
35
comfy/ldm/modules/encoders/noise_aug_modules.py
Normal file
35
comfy/ldm/modules/encoders/noise_aug_modules.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||||
|
from ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    """Noise augmentation for embeddings, with standardisation around the noising step."""

    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        """
        Args:
            clip_stats_path: Optional path to a file holding a ``(mean, std)``
                pair; when ``None`` a zero mean and unit std of length
                ``timestep_dim`` are used.
            timestep_dim: Width of the noise-level embedding (and of the
                default statistics).
            *args, **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        if clip_stats_path is not None:
            # NOTE(review): torch.load unpickles the file -- only point this
            # at trusted statistics files.
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        else:
            clip_mean = torch.zeros(timestep_dim)
            clip_std = torch.ones(timestep_dim)
        # Stored with a leading singleton axis; non-persistent so they are
        # not written to state dicts.
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        """Re-normalize ``x`` to centered mean and unit variance."""
        return (x - self.data_mean) * 1. / self.data_std

    def unscale(self, x):
        """Map ``x`` back to the original data statistics."""
        return (x * self.data_std) + self.data_mean

    def forward(self, x, noise_level=None):
        """Noise ``x`` and return it together with the noise-level embedding.

        Args:
            x: Input batch.
            noise_level: Optional tensor of noise levels; when ``None`` one
                level per sample is drawn uniformly from
                ``[0, self.max_noise_level)``.

        Returns:
            Tuple ``(z, noise_level_embedding)``.
        """
        if noise_level is not None:
            assert isinstance(noise_level, torch.Tensor)
        else:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()

        # Noise is applied in standardised space, then mapped back.
        standardised = self.scale(x)
        noised = self.q_sample(standardised, noise_level)
        z = self.unscale(noised)

        return z, self.time_embed(noise_level)
|
@ -35,6 +35,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if 'strength' in cond[1]:
|
if 'strength' in cond[1]:
|
||||||
strength = cond[1]['strength']
|
strength = cond[1]['strength']
|
||||||
|
|
||||||
|
adm_cond = None
|
||||||
|
if 'adm' in cond[1]:
|
||||||
|
adm_cond = cond[1]['adm']
|
||||||
|
|
||||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
mult = torch.ones_like(input_x) * strength
|
mult = torch.ones_like(input_x) * strength
|
||||||
|
|
||||||
@ -60,6 +64,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
cropped.append(cr)
|
cropped.append(cr)
|
||||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||||
|
|
||||||
|
if adm_cond is not None:
|
||||||
|
conditionning['c_adm'] = adm_cond
|
||||||
|
|
||||||
control = None
|
control = None
|
||||||
if 'control' in cond[1]:
|
if 'control' in cond[1]:
|
||||||
control = cond[1]['control']
|
control = cond[1]['control']
|
||||||
@ -76,6 +83,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if 'c_concat' in c1:
|
if 'c_concat' in c1:
|
||||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||||
return False
|
return False
|
||||||
|
if 'c_adm' in c1:
|
||||||
|
if c1['c_adm'].shape != c2['c_adm'].shape:
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def can_concat_cond(c1, c2):
|
def can_concat_cond(c1, c2):
|
||||||
@ -92,16 +102,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
def cond_cat(c_list):
|
def cond_cat(c_list):
|
||||||
c_crossattn = []
|
c_crossattn = []
|
||||||
c_concat = []
|
c_concat = []
|
||||||
|
c_adm = []
|
||||||
for x in c_list:
|
for x in c_list:
|
||||||
if 'c_crossattn' in x:
|
if 'c_crossattn' in x:
|
||||||
c_crossattn.append(x['c_crossattn'])
|
c_crossattn.append(x['c_crossattn'])
|
||||||
if 'c_concat' in x:
|
if 'c_concat' in x:
|
||||||
c_concat.append(x['c_concat'])
|
c_concat.append(x['c_concat'])
|
||||||
|
if 'c_adm' in x:
|
||||||
|
c_adm.append(x['c_adm'])
|
||||||
out = {}
|
out = {}
|
||||||
if len(c_crossattn) > 0:
|
if len(c_crossattn) > 0:
|
||||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
||||||
if len(c_concat) > 0:
|
if len(c_concat) > 0:
|
||||||
out['c_concat'] = [torch.cat(c_concat)]
|
out['c_concat'] = [torch.cat(c_concat)]
|
||||||
|
if len(c_adm) > 0:
|
||||||
|
out['c_adm'] = torch.cat(c_adm)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
|
||||||
@ -327,6 +342,30 @@ def apply_control_net_to_equal_area(conds, uncond):
|
|||||||
n['control'] = cond_cnets[x]
|
n['control'] = cond_cnets[x]
|
||||||
uncond[temp[1]] = [o[0], n]
|
uncond[temp[1]] = [o[0], n]
|
||||||
|
|
||||||
|
def encode_adm(noise_augmentor, conds, batch_size, device):
    """Turn each conditioning's raw ``'adm'`` entries into a batched tensor.

    Every element of ``conds`` is a two-item list ``[cond, options]``. When
    ``options`` has an ``'adm'`` key it holds ``(clip_vision_output, weight)``
    pairs; each pair is passed through ``noise_augmentor`` at noise level 0,
    concatenated with the returned noise-level embedding, scaled by its
    weight, and all pairs are summed. Conditionings without ``'adm'`` get an
    all-zero tensor of width ``noise_augmentor.time_embed.dim * 2``.

    Args:
        noise_augmentor: Callable taking ``(embeds, noise_level=...)`` and
            returning ``(augmented_embeds, noise_level_embedding)``; must
            expose ``time_embed.dim``.
        conds: List of ``[cond, options_dict]`` entries. The entries are
            updated in place: slot 1 is replaced by a copy of the options
            dict, so the original dict objects are left untouched.
        batch_size: Number of times the result is repeated along dim 0.
        device: Device the tensors are created on / moved to.

    Returns:
        The same ``conds`` list, with ``options['adm']`` now a tensor.
    """
    for entry in conds:
        options = entry[1]
        if 'adm' in options:
            weighted = []
            for adm_c in options["adm"]:
                adm_cond = adm_c[0].image_embeds
                weight = adm_c[1]
                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([0], device=device))
                weighted.append(torch.cat((c_adm, noise_level_emb), 1) * weight)

            adm_out = torch.stack(weighted).sum(0)
            #TODO: Apply Noise to Embedding Mix
        else:
            adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
        # Copy before writing so the dict shared with the caller's graph
        # keeps its original list of (output, weight) pairs.
        entry[1] = options.copy()
        entry[1]["adm"] = torch.cat([adm_out] * batch_size)

    return conds
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
@ -422,10 +461,14 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
precision_scope = contextlib.nullcontext
|
precision_scope = contextlib.nullcontext
|
||||||
|
|
||||||
|
if hasattr(self.model, 'noise_augmentor'): #unclip
|
||||||
|
positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
|
||||||
|
negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
||||||
|
|
||||||
cond_concat = None
|
cond_concat = None
|
||||||
if hasattr(self.model, 'concat_keys'):
|
if hasattr(self.model, 'concat_keys'): #inpaint
|
||||||
cond_concat = []
|
cond_concat = []
|
||||||
for ck in self.model.concat_keys:
|
for ck in self.model.concat_keys:
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
|
93
comfy/sd.py
93
comfy/sd.py
@ -12,20 +12,7 @@ from .cldm import cldm
|
|||||||
from .t2i_adapter import adapter
|
from .t2i_adapter import adapter
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
|
from . import clip_vision
|
||||||
def load_torch_file(ckpt):
|
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
|
||||||
import safetensors.torch
|
|
||||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
|
||||||
else:
|
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
||||||
if "global_step" in pl_sd:
|
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
|
||||||
if "state_dict" in pl_sd:
|
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
else:
|
|
||||||
sd = pl_sd
|
|
||||||
return sd
|
|
||||||
|
|
||||||
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
@ -53,30 +40,7 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
|
|||||||
if x in sd:
|
if x in sd:
|
||||||
sd[keys_to_replace[x]] = sd.pop(x)
|
sd[keys_to_replace[x]] = sd.pop(x)
|
||||||
|
|
||||||
resblock_to_replace = {
|
sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24)
|
||||||
"ln_1": "layer_norm1",
|
|
||||||
"ln_2": "layer_norm2",
|
|
||||||
"mlp.c_fc": "mlp.fc1",
|
|
||||||
"mlp.c_proj": "mlp.fc2",
|
|
||||||
"attn.out_proj": "self_attn.out_proj",
|
|
||||||
}
|
|
||||||
|
|
||||||
for resblock in range(24):
|
|
||||||
for x in resblock_to_replace:
|
|
||||||
for y in ["weight", "bias"]:
|
|
||||||
k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
|
|
||||||
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
|
|
||||||
if k in sd:
|
|
||||||
sd[k_to] = sd.pop(k)
|
|
||||||
|
|
||||||
for y in ["weight", "bias"]:
|
|
||||||
k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
|
|
||||||
if k_from in sd:
|
|
||||||
weights = sd.pop(k_from)
|
|
||||||
for x in range(3):
|
|
||||||
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
|
||||||
k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
|
|
||||||
sd[k_to] = weights[1024*x:1024*(x + 1)]
|
|
||||||
|
|
||||||
for x in load_state_dict_to:
|
for x in load_state_dict_to:
|
||||||
x.load_state_dict(sd, strict=False)
|
x.load_state_dict(sd, strict=False)
|
||||||
@ -123,7 +87,7 @@ LORA_UNET_MAP_RESNET = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def load_lora(path, to_load):
|
def load_lora(path, to_load):
|
||||||
lora = load_torch_file(path)
|
lora = utils.load_torch_file(path)
|
||||||
patch_dict = {}
|
patch_dict = {}
|
||||||
loaded_keys = set()
|
loaded_keys = set()
|
||||||
for x in to_load:
|
for x in to_load:
|
||||||
@ -599,7 +563,7 @@ class ControlNet:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = load_torch_file(ckpt_path)
|
controlnet_data = utils.load_torch_file(ckpt_path)
|
||||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||||
pth = False
|
pth = False
|
||||||
sd2 = False
|
sd2 = False
|
||||||
@ -793,7 +757,7 @@ class StyleModel:
|
|||||||
|
|
||||||
|
|
||||||
def load_style_model(ckpt_path):
|
def load_style_model(ckpt_path):
|
||||||
model_data = load_torch_file(ckpt_path)
|
model_data = utils.load_torch_file(ckpt_path)
|
||||||
keys = model_data.keys()
|
keys = model_data.keys()
|
||||||
if "style_embedding" in keys:
|
if "style_embedding" in keys:
|
||||||
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
|
||||||
@ -804,7 +768,7 @@ def load_style_model(ckpt_path):
|
|||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_path, embedding_directory=None):
|
def load_clip(ckpt_path, embedding_directory=None):
|
||||||
clip_data = load_torch_file(ckpt_path)
|
clip_data = utils.load_torch_file(ckpt_path)
|
||||||
config = {}
|
config = {}
|
||||||
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
|
if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
|
||||||
config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
|
||||||
@ -847,7 +811,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
|
|||||||
load_state_dict_to = [w]
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
model = instantiate_from_config(config["model"])
|
model = instantiate_from_config(config["model"])
|
||||||
sd = load_torch_file(ckpt_path)
|
sd = utils.load_torch_file(ckpt_path)
|
||||||
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
||||||
|
|
||||||
if fp16:
|
if fp16:
|
||||||
@ -856,10 +820,11 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
|
|||||||
return (ModelPatcher(model), clip, vae)
|
return (ModelPatcher(model), clip, vae)
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
|
||||||
sd = load_torch_file(ckpt_path)
|
sd = utils.load_torch_file(ckpt_path)
|
||||||
sd_keys = sd.keys()
|
sd_keys = sd.keys()
|
||||||
clip = None
|
clip = None
|
||||||
|
clipvision = None
|
||||||
vae = None
|
vae = None
|
||||||
|
|
||||||
fp16 = model_management.should_use_fp16()
|
fp16 = model_management.should_use_fp16()
|
||||||
@ -884,6 +849,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
|
|||||||
w.cond_stage_model = clip.cond_stage_model
|
w.cond_stage_model = clip.cond_stage_model
|
||||||
load_state_dict_to = [w]
|
load_state_dict_to = [w]
|
||||||
|
|
||||||
|
clipvision_key = "embedder.model.visual.transformer.resblocks.0.attn.in_proj_weight"
|
||||||
|
noise_aug_config = None
|
||||||
|
if clipvision_key in sd_keys:
|
||||||
|
size = sd[clipvision_key].shape[1]
|
||||||
|
|
||||||
|
if output_clipvision:
|
||||||
|
clipvision = clip_vision.load_clipvision_from_sd(sd)
|
||||||
|
|
||||||
|
noise_aug_key = "noise_augmentor.betas"
|
||||||
|
if noise_aug_key in sd_keys:
|
||||||
|
noise_aug_config = {}
|
||||||
|
params = {}
|
||||||
|
noise_schedule_config = {}
|
||||||
|
noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0]
|
||||||
|
noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2"
|
||||||
|
params["noise_schedule_config"] = noise_schedule_config
|
||||||
|
noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation"
|
||||||
|
if size == 1280: #h
|
||||||
|
params["timestep_dim"] = 1024
|
||||||
|
elif size == 1024: #l
|
||||||
|
params["timestep_dim"] = 768
|
||||||
|
noise_aug_config['params'] = params
|
||||||
|
|
||||||
sd_config = {
|
sd_config = {
|
||||||
"linear_start": 0.00085,
|
"linear_start": 0.00085,
|
||||||
"linear_end": 0.012,
|
"linear_end": 0.012,
|
||||||
@ -932,7 +920,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
|
|||||||
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
|
||||||
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
|
||||||
|
|
||||||
if unet_config["in_channels"] > 4: #inpainting model
|
if noise_aug_config is not None: #SD2.x unclip model
|
||||||
|
sd_config["noise_aug_config"] = noise_aug_config
|
||||||
|
sd_config["image_size"] = 96
|
||||||
|
sd_config["embedding_dropout"] = 0.25
|
||||||
|
sd_config["conditioning_key"] = 'crossattn-adm'
|
||||||
|
model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
|
||||||
|
elif unet_config["in_channels"] > 4: #inpainting model
|
||||||
sd_config["conditioning_key"] = "hybrid"
|
sd_config["conditioning_key"] = "hybrid"
|
||||||
sd_config["finetune_keys"] = None
|
sd_config["finetune_keys"] = None
|
||||||
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||||
@ -944,6 +938,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
|
|||||||
else:
|
else:
|
||||||
unet_config["num_heads"] = 8 #SD1.x
|
unet_config["num_heads"] = 8 #SD1.x
|
||||||
|
|
||||||
|
unclip = 'model.diffusion_model.label_emb.0.0.weight'
|
||||||
|
if unclip in sd_keys:
|
||||||
|
unet_config["num_classes"] = "sequential"
|
||||||
|
unet_config["adm_in_channels"] = sd[unclip].shape[1]
|
||||||
|
|
||||||
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
|
if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
|
||||||
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
|
k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
|
||||||
out = sd[k]
|
out = sd[k]
|
||||||
@ -956,4 +955,4 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
|
|||||||
if fp16:
|
if fp16:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
|
|
||||||
return (ModelPatcher(model), clip, vae)
|
return (ModelPatcher(model), clip, vae, clipvision)
|
||||||
|
@ -1,5 +1,47 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
def load_torch_file(ckpt):
    """Load a checkpoint file onto the CPU and return its state dict.

    Files whose name ends in ``.safetensors`` (case-insensitive) are read
    with ``safetensors.torch.load_file``. Anything else goes through
    ``torch.load``; if the loaded object has a ``"state_dict"`` key that
    value is returned, otherwise the loaded object itself is returned.
    When a ``"global_step"`` key is present its value is printed.

    Args:
        ckpt: Path to the checkpoint file.

    Returns:
        The mapping of parameter names to tensors.
    """
    if ckpt.lower().endswith(".safetensors"):
        # Imported lazily so safetensors is only required for that format.
        import safetensors.torch
        return safetensors.torch.load_file(ckpt, device="cpu")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    return checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||||
|
|
||||||
|
def transformers_convert(sd, prefix_from, prefix_to, number):
    """Rename transformer resblock keys in ``sd`` to the ``encoder.layers`` layout.

    For each block index in ``range(number)``, keys of the form
    ``{prefix_from}.transformer.resblocks.{i}.<name>.<weight|bias>`` are moved
    to ``{prefix_to}.encoder.layers.{i}.<new name>.<weight|bias>``. The fused
    ``attn.in_proj_<weight|bias>`` entry is split along its first dimension
    into three equal slices stored as ``self_attn.q_proj``, ``self_attn.k_proj``
    and ``self_attn.v_proj``.

    Args:
        sd: State dict to convert. It is modified in place.
        prefix_from: Key prefix of the source layout.
        prefix_to: Key prefix of the destination layout.
        number: How many resblocks to process.

    Returns:
        The same ``sd`` object, after renaming. Keys that are absent are
        skipped silently; unrelated keys are left untouched.
    """
    resblock_to_replace = {
        "ln_1": "layer_norm1",
        "ln_2": "layer_norm2",
        "mlp.c_fc": "mlp.fc1",
        "mlp.c_proj": "mlp.fc2",
        "attn.out_proj": "self_attn.out_proj",
    }
    # Order matters: slice 0 -> q, slice 1 -> k, slice 2 -> v.
    qkv_targets = ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj")
    suffixes = ("weight", "bias")

    for resblock in range(number):
        block_from = "{}.transformer.resblocks.{}".format(prefix_from, resblock)
        block_to = "{}.encoder.layers.{}".format(prefix_to, resblock)

        for old_name, new_name in resblock_to_replace.items():
            for suffix in suffixes:
                k = "{}.{}.{}".format(block_from, old_name, suffix)
                if k in sd:
                    sd["{}.{}.{}".format(block_to, new_name, suffix)] = sd.pop(k)

        for suffix in suffixes:
            k_from = "{}.attn.in_proj_{}".format(block_from, suffix)
            if k_from in sd:
                weights = sd.pop(k_from)
                # The fused projection stacks q, k and v along dim 0.
                shape_from = weights.shape[0] // 3
                for index, target in enumerate(qkv_targets):
                    k_to = "{}.{}.{}".format(block_to, target, suffix)
                    sd[k_to] = weights[shape_from * index:shape_from * (index + 1)]
    return sd
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
if crop == "center":
|
if crop == "center":
|
||||||
old_width = samples.shape[3]
|
old_width = samples.shape[3]
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
|
|
||||||
from comfy.sd import load_torch_file
|
|
||||||
import os
|
|
||||||
|
|
||||||
class ClipVisionModel():
|
|
||||||
def __init__(self):
|
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config.json")
|
|
||||||
config = CLIPVisionConfig.from_json_file(json_config)
|
|
||||||
self.model = CLIPVisionModel(config)
|
|
||||||
self.processor = CLIPImageProcessor(crop_size=224,
|
|
||||||
do_center_crop=True,
|
|
||||||
do_convert_rgb=True,
|
|
||||||
do_normalize=True,
|
|
||||||
do_resize=True,
|
|
||||||
image_mean=[ 0.48145466,0.4578275,0.40821073],
|
|
||||||
image_std=[0.26862954,0.26130258,0.27577711],
|
|
||||||
resample=3, #bicubic
|
|
||||||
size=224)
|
|
||||||
|
|
||||||
def load_sd(self, sd):
|
|
||||||
self.model.load_state_dict(sd, strict=False)
|
|
||||||
|
|
||||||
def encode_image(self, image):
|
|
||||||
inputs = self.processor(images=[image[0]], return_tensors="pt")
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def load(ckpt_path):
|
|
||||||
clip_data = load_torch_file(ckpt_path)
|
|
||||||
clip = ClipVisionModel()
|
|
||||||
clip.load_sd(clip_data)
|
|
||||||
return clip
|
|
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from comfy_extras.chainner_models import model_loading
|
from comfy_extras.chainner_models import model_loading
|
||||||
from comfy.sd import load_torch_file
|
|
||||||
import model_management
|
import model_management
|
||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -18,7 +17,7 @@ class UpscaleModelLoader:
|
|||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||||
sd = load_torch_file(model_path)
|
sd = comfy.utils.load_torch_file(model_path)
|
||||||
out = model_loading.load_state_dict(sd).eval()
|
out = model_loading.load_state_dict(sd).eval()
|
||||||
return (out, )
|
return (out, )
|
||||||
|
|
||||||
|
49
nodes.py
49
nodes.py
@ -18,7 +18,7 @@ import comfy.samplers
|
|||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
import comfy_extras.clip_vision
|
import comfy.clip_vision
|
||||||
|
|
||||||
import model_management
|
import model_management
|
||||||
import importlib
|
import importlib
|
||||||
@ -219,6 +219,21 @@ class CheckpointLoaderSimple:
|
|||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class unCLIPCheckpointLoader:
    """Node that loads a checkpoint and also returns its CLIP vision model.

    The outputs are, in order, ``MODEL``, ``CLIP``, ``VAE`` and
    ``CLIP_VISION``, produced by ``comfy.sd.load_checkpoint_guess_config``
    with ``output_clipvision=True``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                             }}
    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
    FUNCTION = "load_checkpoint"

    CATEGORY = "_for_testing/unclip"

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
        """Load ``ckpt_name`` from the ``checkpoints`` folder.

        Args:
            ckpt_name: File name as listed by ``folder_paths``.
            output_vae: Forwarded to the loader (default ``True``).
            output_clip: Forwarded to the loader (default ``True``).

        Returns:
            Whatever ``comfy.sd.load_checkpoint_guess_config`` returns.
        """
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        # Forward the caller's flags instead of hard-coding True, so the
        # parameters in the signature actually take effect.
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return out
|
||||||
|
|
||||||
class CLIPSetLastLayer:
|
class CLIPSetLastLayer:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -370,7 +385,7 @@ class CLIPVisionLoader:
|
|||||||
|
|
||||||
def load_clip(self, clip_name):
|
def load_clip(self, clip_name):
|
||||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
||||||
clip_vision = comfy_extras.clip_vision.load(clip_path)
|
clip_vision = comfy.clip_vision.load(clip_path)
|
||||||
return (clip_vision,)
|
return (clip_vision,)
|
||||||
|
|
||||||
class CLIPVisionEncode:
|
class CLIPVisionEncode:
|
||||||
@ -382,7 +397,7 @@ class CLIPVisionEncode:
|
|||||||
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
|
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
CATEGORY = "conditioning/style_model"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def encode(self, clip_vision, image):
|
def encode(self, clip_vision, image):
|
||||||
output = clip_vision.encode_image(image)
|
output = clip_vision.encode_image(image)
|
||||||
@ -424,6 +439,32 @@ class StyleModelApply:
|
|||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
class unCLIPConditioning:
    """Node that attaches a CLIP vision output to every conditioning entry.

    Each ``(clip_vision_output, strength)`` pair is appended to the entry's
    ``"adm"`` list, so applying the node repeatedly accumulates several
    image conditionings.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_adm"

    CATEGORY = "_for_testing/unclip"

    def apply_adm(self, conditioning, clip_vision_output, strength):
        """Return a new conditioning list with the image cond appended.

        Args:
            conditioning: List of ``[cond, options_dict]`` entries.
            clip_vision_output: Output to record for each entry.
            strength: Weight stored alongside the output.

        Returns:
            A one-element tuple holding the new list. The input list, its
            option dicts and any existing ``"adm"`` lists are not modified.
        """
        # The same pair is attached to every entry, so build it once.
        adm_entry = (clip_vision_output, strength)
        c = []
        for t in conditioning:
            o = t[1].copy()
            # List concatenation already yields a new list, which keeps the
            # input's "adm" list untouched without an extra slice copy.
            o["adm"] = o.get("adm", []) + [adm_entry]
            c.append([t[0], o])
        return (c, )
|
||||||
|
|
||||||
|
|
||||||
class EmptyLatentImage:
|
class EmptyLatentImage:
|
||||||
def __init__(self, device="cpu"):
|
def __init__(self, device="cpu"):
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -1025,6 +1066,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CLIPLoader": CLIPLoader,
|
"CLIPLoader": CLIPLoader,
|
||||||
"CLIPVisionEncode": CLIPVisionEncode,
|
"CLIPVisionEncode": CLIPVisionEncode,
|
||||||
"StyleModelApply": StyleModelApply,
|
"StyleModelApply": StyleModelApply,
|
||||||
|
"unCLIPConditioning": unCLIPConditioning,
|
||||||
"ControlNetApply": ControlNetApply,
|
"ControlNetApply": ControlNetApply,
|
||||||
"ControlNetLoader": ControlNetLoader,
|
"ControlNetLoader": ControlNetLoader,
|
||||||
"DiffControlNetLoader": DiffControlNetLoader,
|
"DiffControlNetLoader": DiffControlNetLoader,
|
||||||
@ -1033,6 +1075,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"VAEDecodeTiled": VAEDecodeTiled,
|
"VAEDecodeTiled": VAEDecodeTiled,
|
||||||
"VAEEncodeTiled": VAEEncodeTiled,
|
"VAEEncodeTiled": VAEEncodeTiled,
|
||||||
"TomePatchModel": TomePatchModel,
|
"TomePatchModel": TomePatchModel,
|
||||||
|
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_custom_node(module_path):
|
def load_custom_node(module_path):
|
||||||
|
Loading…
Reference in New Issue
Block a user