Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-25 15:55:18 +00:00)
CLIP code refactor and improvements.
More generic CLIP model class that can be used on more types of text encoders. Don't apply the weighting algorithm when the weight is 1.0. Don't compute an empty token output when it's not needed.
This commit is contained in:
parent b3fcd64c6c · commit 656c0b5d90
comfy/sd1_clip.py

@@ -8,32 +8,54 @@ import zipfile
 from . import model_management
 import contextlib
 
+def gen_empty_tokens(special_tokens, length):
+    start_token = special_tokens.get("start", None)
+    end_token = special_tokens.get("end", None)
+    pad_token = special_tokens.get("pad")
+    output = []
+    if start_token is not None:
+        output.append(start_token)
+    if end_token is not None:
+        output.append(end_token)
+    output += [pad_token] * (length - len(output))
+    return output
+
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        to_encode = list(self.empty_tokens)
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
         for x in token_weight_pairs:
             tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
             to_encode.append(tokens)
 
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
         out, pooled = self.encode(to_encode)
-        z_empty = out[0:1]
-        if pooled.shape[0] > 1:
-            first_pooled = pooled[1:2]
+        if pooled is not None:
+            first_pooled = pooled[0:1].cpu()
         else:
-            first_pooled = pooled[0:1]
+            first_pooled = pooled
 
         output = []
-        for k in range(1, out.shape[0]):
+        for k in range(0, sections):
             z = out[k:k+1]
-            for i in range(len(z)):
-                for j in range(len(z[i])):
-                    weight = token_weight_pairs[k - 1][j][1]
-                    z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
             output.append(z)
 
         if (len(output) == 0):
-            return z_empty.cpu(), first_pooled.cpu()
-        return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
+            return out[-1:].cpu(), first_pooled
+        return torch.cat(output, dim=-2).cpu(), first_pooled
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
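For readers skimming the hunk above, the heart of the change is the prompt-weighting interpolation: each token's embedding is pulled toward, or pushed away from, the embedding the encoder produces for an "empty" (pad-only) prompt, and that empty prompt is now only encoded when some weight actually differs from 1.0. The following is a minimal standalone sketch of that idea with made-up tensor sizes and simplified helper names; it is not the actual ComfyUI code path.

import torch

def gen_empty_tokens(special_tokens, length):
    # Same idea as the helper added above: optional start/end tokens, then padding.
    output = []
    if special_tokens.get("start") is not None:
        output.append(special_tokens["start"])
    if special_tokens.get("end") is not None:
        output.append(special_tokens["end"])
    output += [special_tokens["pad"]] * (length - len(output))
    return output

def apply_token_weights(z, z_empty, weights):
    # z:       [1, T, C] encoder output for the weighted prompt
    # z_empty: [T, C]    encoder output for the empty prompt
    # weights: T floats; 1.0 leaves a token untouched (the new fast path)
    z = z.clone()
    for j, w in enumerate(weights):
        if w != 1.0:
            z[0, j] = (z[0, j] - z_empty[j]) * w + z_empty[j]
    return z

# Toy usage with random stand-ins for encoder outputs:
T, C = 4, 8
weighted = apply_token_weights(torch.randn(1, T, C), torch.randn(T, C), [1.0, 1.2, 0.8, 1.0])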
@@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, config_class=CLIPTextConfig,
+                 model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
         self.num_layers = 12
         if textmodel_path is not None:
-            self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
+            self.transformer = model_class.from_pretrained(textmodel_path)
         else:
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = CLIPTextConfig.from_json_file(textmodel_json_config)
+            config = config_class.from_json_file(textmodel_json_config)
             self.num_layers = config.num_hidden_layers
             with comfy.ops.use_comfy_ops(device, dtype):
                 with modeling_utils.no_init_weights():
-                    self.transformer = CLIPTextModel(config)
+                    self.transformer = model_class(config)
 
+        self.inner_name = inner_name
         if dtype is not None:
             self.transformer.to(dtype)
-            self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
-            self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
+            inner_model = getattr(self.transformer, self.inner_name)
+            if hasattr(inner_model, "embeddings"):
+                inner_model.embeddings.to(torch.float32)
+            else:
+                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
 
         self.max_length = max_length
         if freeze:
             self.freeze()
         self.layer = layer
         self.layer_idx = None
-        self.empty_tokens = [[49406] + [49407] * 76]
+        self.special_tokens = special_tokens
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = False
 
-        self.layer_norm_hidden_state = True
+        self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) <= self.num_layers
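The constructor changes above replace hard-coded CLIP assumptions (CLIPTextConfig, CLIPTextModel, the text_model attribute, the fixed empty prompt) with parameters, which is what makes the class reusable for other text encoders. As a purely hypothetical sketch of how a subclass might use the new keywords — the T5 classes, the config path, and the token ids here are illustrative assumptions, not something this commit ships, and whether a given encoder works end to end depends on code outside this diff:

from transformers import T5Config, T5EncoderModel  # hypothetical encoder choice
from comfy import sd1_clip

class MyT5TextEncoder(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", dtype=None):
        super().__init__(device=device, dtype=dtype, layer="last",
                         textmodel_json_config="my_t5_config.json",  # placeholder path
                         special_tokens={"end": 1, "pad": 0},        # T5 has no start token
                         layer_norm_hidden_state=False,
                         config_class=T5Config,
                         model_class=T5EncoderModel,
                         inner_name="encoder")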
@@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 else:
                     print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
             while len(tokens_temp) < len(x):
-                tokens_temp += [self.empty_tokens[0][-1]]
+                tokens_temp += [self.special_tokens["pad"]]
             out_tokens += [tokens_temp]
 
         n = token_dict_size
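The one-line change above is small but load-bearing: padding inside set_up_textual_embeddings used to grab the last token of the hard-coded empty prompt (self.empty_tokens[0][-1]), which only happened to be the pad token for SD1-style CLIP; now it reads the declared pad id directly. A reduced sketch of that padding step, with a hypothetical helper name:

def pad_to_length(tokens_temp, target_len, special_tokens):
    # Fill the token list up to the prompt length with the declared pad token,
    # instead of inferring it from the tail of a template prompt.
    while len(tokens_temp) < target_len:
        tokens_temp.append(special_tokens["pad"])
    return tokens_temp

pad_to_length([49406, 320, 49407], 6, {"start": 49406, "end": 49407, "pad": 49407})
# -> [49406, 320, 49407, 49407, 49407, 49407]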
@@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
+        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, b: contextlib.nullcontext(a)
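Context for the hunk above: the forward pass chooses torch.autocast when the encoder's final layer norm is not stored in float32, and a do-nothing context otherwise; the only change is that the layer norm is now found through the configurable inner_name instead of the CLIP-specific text_model attribute. A stripped-down sketch of that selection (the helper name is made up):

import contextlib
import torch

def pick_precision_scope(layer_norm_dtype):
    # Autocast only when the weights are not already float32; otherwise a no-op
    # context manager that simply swallows the device argument.
    if layer_norm_dtype != torch.float32:
        return torch.autocast
    return lambda device, dtype: contextlib.nullcontext(device)

scope = pick_precision_scope(torch.float32)  # fp32 weights -> null-context branch
with scope("cuda", torch.float32):
    pass  # the transformer forward pass would run here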
@@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             else:
                 z = outputs.hidden_states[self.layer_idx]
                 if self.layer_norm_hidden_state:
-                    z = self.transformer.text_model.final_layer_norm(z)
+                    z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
 
-            pooled_output = outputs.pooler_output
-            if self.text_projection is not None:
+            if hasattr(outputs, "pooler_output"):
+                pooled_output = outputs.pooler_output.float()
+            else:
+                pooled_output = None
+
+            if self.text_projection is not None and pooled_output is not None:
                 pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
-        return z.float(), pooled_output.float()
+        return z.float(), pooled_output
 
     def encode(self, tokens):
         return self(tokens)
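The pooled-output handling above is what lets non-CLIP encoders through: models without a pooler_output now yield None instead of raising, and the learned text_projection is only applied when there is something to project. A reduced sketch of that logic (project_pooled is an illustrative name, not a function in the codebase):

from types import SimpleNamespace
import torch

def project_pooled(outputs, text_projection):
    # Encoders with no pooler_output (hasattr is False) return None for the pooled vector.
    pooled = outputs.pooler_output.float() if hasattr(outputs, "pooler_output") else None
    if text_projection is not None and pooled is not None:
        pooled = pooled.to(text_projection.device) @ text_projection.float()
    return pooled

proj = torch.nn.Parameter(torch.eye(768))
project_pooled(SimpleNamespace(pooler_output=torch.randn(1, 768)), proj)  # -> [1, 768] tensor
project_pooled(SimpleNamespace(), proj)                                   # -> None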
comfy/sd2_clip.py

@@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
             layer_idx=23
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
 
 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
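A quick illustration of why declaring special_tokens is enough for SD2: the empty prompt used by the weighting code is rebuilt on demand from those ids, so the old self.empty_tokens override becomes redundant. Assuming the helper from the first hunk is importable as comfy.sd1_clip.gen_empty_tokens after this commit:

from comfy.sd1_clip import gen_empty_tokens

# SD1-style CLIP-L pads with the end-of-text token:
gen_empty_tokens({"start": 49406, "end": 49407, "pad": 49407}, 77)
# -> [49406, 49407, 49407, ..., 49407]  (77 ids, the old [[49406] + [49407] * 76])

# SD2 CLIP-H (and SDXL CLIP-G) pad with id 0:
gen_empty_tokens({"start": 49406, "end": 49407, "pad": 0}, 77)
# -> [49406, 49407, 0, 0, ..., 0]       (the old [[49406] + [49407] + [0] * 75])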
comfy/sdxl_clip.py

@@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        self.layer_norm_hidden_state = False
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
@@ -38,8 +37,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
-        self.clip_l.layer_norm_hidden_state = False
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
     def clip_layer(self, layer_idx):
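After the two SDXL hunks, both text encoders are configured entirely through constructor arguments; nothing is patched onto the objects after they are created. A small check of that, assuming the comfy package is importable (building the models loads their JSON configs):

import torch
from comfy import sdxl_clip

model = sdxl_clip.SDXLClipModel(device="cpu", dtype=torch.float16)
print(model.clip_l.layer_norm_hidden_state)  # False, set via the new keyword argument
print(model.clip_g.special_tokens["pad"])    # 0, declared instead of overriding empty_tokens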