CLIP code refactor and improvements.

More generic clip model class that can be used on more types of text
encoders.

Don't apply weighting algorithm when weight is 1.0

Don't compute an empty token output when it's not needed.
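
The last two points refer to the prompt-weighting step in encode_token_weights: each weighted token embedding is interpolated toward the embedding of an empty prompt, so the empty prompt only has to be encoded, and the per-token loop only has to run, when at least one weight differs from 1.0. A minimal standalone sketch of that fast path (illustrative names, not the actual ComfyUI helpers):

def apply_token_weights(z, weights, encode_empty):
    # z: [1, T, C] encoder output for one prompt section; weights: T floats.
    # encode_empty: callable returning the [T, C] embedding of an empty prompt,
    # invoked only when it is actually needed.
    if all(w == 1.0 for w in weights):
        return z  # fast path: no empty-prompt encode, no per-token loop
    z_empty = encode_empty()
    for j, w in enumerate(weights):
        if w != 1.0:
            z[0][j] = (z[0][j] - z_empty[j]) * w + z_empty[j]
    return z
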
comfyanonymous 2023-11-06 13:43:50 -05:00
parent b3fcd64c6c
commit 656c0b5d90
3 changed files with 62 additions and 33 deletions
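
The generality comes from the widened SDClipModel constructor in the first file below: the transformers config and model classes, the name of the inner text model, the special tokens, and the final-layer-norm behaviour are now constructor arguments instead of CLIP-only hard-coded values. Purely as a hypothetical illustration (not part of this commit, and untested), a non-CLIP encoder such as a T5 encoder could then be wired up declaratively:

from transformers import T5Config, T5EncoderModel

# Hypothetical use of the now-generic class; "t5_config.json" is a placeholder
# path.  Token ids 1 (</s>) and 0 (<pad>) are T5's defaults, and T5 has no
# start token, which the new gen_empty_tokens helper tolerates.
t5_encoder = SDClipModel(               # SDClipModel as patched in the diff below
    layer="last",
    textmodel_json_config="t5_config.json",
    special_tokens={"end": 1, "pad": 0},
    layer_norm_hidden_state=False,
    config_class=T5Config,
    model_class=T5EncoderModel,
    inner_name="encoder",
)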


@@ -8,32 +8,54 @@ import zipfile
 from . import model_management
 import contextlib
 
+def gen_empty_tokens(special_tokens, length):
+    start_token = special_tokens.get("start", None)
+    end_token = special_tokens.get("end", None)
+    pad_token = special_tokens.get("pad")
+    output = []
+    if start_token is not None:
+        output.append(start_token)
+    if end_token is not None:
+        output.append(end_token)
+    output += [pad_token] * (length - len(output))
+    return output
+
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        to_encode = list(self.empty_tokens)
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
         for x in token_weight_pairs:
             tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
             to_encode.append(tokens)
 
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
         out, pooled = self.encode(to_encode)
-        z_empty = out[0:1]
-        if pooled.shape[0] > 1:
-            first_pooled = pooled[1:2]
+        if pooled is not None:
+            first_pooled = pooled[0:1].cpu()
         else:
-            first_pooled = pooled[0:1]
+            first_pooled = pooled
 
         output = []
-        for k in range(1, out.shape[0]):
+        for k in range(0, sections):
             z = out[k:k+1]
-            for i in range(len(z)):
-                for j in range(len(z[i])):
-                    weight = token_weight_pairs[k - 1][j][1]
-                    z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
             output.append(z)
 
         if (len(output) == 0):
-            return z_empty.cpu(), first_pooled.cpu()
-        return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
+            return out[-1:].cpu(), first_pooled
+        return torch.cat(output, dim=-2).cpu(), first_pooled
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
 
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None):  # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, config_class=CLIPTextConfig,
+                 model_class=CLIPTextModel, inner_name="text_model"):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
         self.num_layers = 12
         if textmodel_path is not None:
-            self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
+            self.transformer = model_class.from_pretrained(textmodel_path)
         else:
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = CLIPTextConfig.from_json_file(textmodel_json_config)
+            config = config_class.from_json_file(textmodel_json_config)
             self.num_layers = config.num_hidden_layers
             with comfy.ops.use_comfy_ops(device, dtype):
                 with modeling_utils.no_init_weights():
-                    self.transformer = CLIPTextModel(config)
+                    self.transformer = model_class(config)
 
+        self.inner_name = inner_name
         if dtype is not None:
             self.transformer.to(dtype)
-            self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
-            self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
+            inner_model = getattr(self.transformer, self.inner_name)
+            if hasattr(inner_model, "embeddings"):
+                inner_model.embeddings.to(torch.float32)
+            else:
+                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
 
         self.max_length = max_length
         if freeze:
             self.freeze()
         self.layer = layer
         self.layer_idx = None
-        self.empty_tokens = [[49406] + [49407] * 76]
+        self.special_tokens = special_tokens
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = False
-        self.layer_norm_hidden_state = True
+        self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) <= self.num_layers
@@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     else:
                         print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
             while len(tokens_temp) < len(x):
-                tokens_temp += [self.empty_tokens[0][-1]]
+                tokens_temp += [self.special_tokens["pad"]]
             out_tokens += [tokens_temp]
 
         n = token_dict_size
@@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
+        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, b: contextlib.nullcontext(a)
@@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             else:
                 z = outputs.hidden_states[self.layer_idx]
                 if self.layer_norm_hidden_state:
-                    z = self.transformer.text_model.final_layer_norm(z)
+                    z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
 
-            pooled_output = outputs.pooler_output
-            if self.text_projection is not None:
+            if hasattr(outputs, "pooler_output"):
+                pooled_output = outputs.pooler_output.float()
+            else:
+                pooled_output = None
+
+            if self.text_projection is not None and pooled_output is not None:
                 pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
-        return z.float(), pooled_output.float()
+        return z.float(), pooled_output
 
     def encode(self, tokens):
         return self(tokens)


@@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
             layer_idx=23
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
 
 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):


@@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        self.layer_norm_hidden_state = False
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
@@ -38,8 +37,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
-        self.clip_l.layer_norm_hidden_state = False
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
     def clip_layer(self, layer_idx):