CLIP code refactor and improvements.

More generic clip model class that can be used on more types of text
encoders.

Don't apply weighting algorithm when weight is 1.0

Don't compute an empty token output when it's not needed.
This commit is contained in:
comfyanonymous 2023-11-06 13:43:50 -05:00
parent b3fcd64c6c
commit 656c0b5d90
3 changed files with 62 additions and 33 deletions

View File

@ -8,32 +8,54 @@ import zipfile
from . import model_management
import contextlib
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
end_token = special_tokens.get("end", None)
pad_token = special_tokens.get("pad")
output = []
if start_token is not None:
output.append(start_token)
if end_token is not None:
output.append(end_token)
output += [pad_token] * (length - len(output))
return output
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list(self.empty_tokens)
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
z_empty = out[0:1]
if pooled.shape[0] > 1:
first_pooled = pooled[1:2]
if pooled is not None:
first_pooled = pooled[0:1].cpu()
else:
first_pooled = pooled[0:1]
first_pooled = pooled
output = []
for k in range(1, out.shape[0]):
for k in range(0, sections):
z = out[k:k+1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k - 1][j][1]
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None:
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
self.transformer = model_class.from_pretrained(textmodel_path)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = None
self.empty_tokens = [[49406] + [49407] * 76]
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False
self.layer_norm_hidden_state = True
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers
@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]]
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
n = token_dict_size
@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
pooled_output = outputs.pooler_output
if self.text_projection is not None:
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output.float()
return z.float(), pooled_output
def encode(self, tokens):
return self(tokens)

View File

@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
layer_idx=23
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):

View File

@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.layer_norm_hidden_state = False
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
return super().load_sd(sd)
@ -38,8 +37,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l.layer_norm_hidden_state = False
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):