Always return unprojected pooled output for gligen.

This commit is contained in:
comfyanonymous 2024-02-25 07:20:31 -05:00
parent 1cb3f6a83b
commit c2cb8e889b
6 changed files with 38 additions and 28 deletions

View File

@ -133,7 +133,7 @@ class CLIPTextModel(torch.nn.Module):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs) x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2]) out = self.text_projection(x[2])
return (x[0], x[1], out) return (x[0], x[1], out, x[2])
class CLIPVisionEmbeddings(torch.nn.Module): class CLIPVisionEmbeddings(torch.nn.Module):

View File

@ -201,9 +201,9 @@ def model_lora_keys_clip(model, key_map={}):
key_map[lora_key] = k key_map[lora_key] = k
k = "clip_g.text_projection" k = "clip_g.transformer.text_projection.weight"
if k in sdk: if k in sdk:
key_map["lora_prior_te_text_projection"] = k #cascade lora key_map["lora_prior_te_text_projection"] = k #cascade lora?
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
# key_map["lora_te_text_projection"] = k # key_map["lora_te_text_projection"] = k

View File

@ -123,10 +123,13 @@ class CLIP:
return self.tokenizer.tokenize_with_weights(text, return_word_ids) return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens, return_pooled=False): def encode_from_tokens(self, tokens, return_pooled=False):
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None: if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx) self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
else:
self.cond_stage_model.reset_clip_layer() if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model() self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = self.cond_stage_model.encode_token_weights(tokens)

View File

@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.enable_attention_masks = enable_attention_masks self.enable_attention_masks = enable_attention_masks
self.layer_norm_hidden_state = layer_norm_hidden_state self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = True
if layer == "hidden": if layer == "hidden":
assert layer_idx is not None assert layer_idx is not None
assert abs(layer_idx) < self.num_layers assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx) self.set_clip_options({"layer": layer_idx})
self.layer_default = (self.layer, self.layer_idx) self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
def clip_layer(self, layer_idx): def set_clip_options(self, options):
if abs(layer_idx) > self.num_layers: layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last" self.layer = "last"
else: else:
self.layer = "hidden" self.layer = "hidden"
self.layer_idx = layer_idx self.layer_idx = layer_idx
def reset_clip_layer(self): def reset_clip_options(self):
self.layer = self.layer_default[0] self.layer = self.options_default[0]
self.layer_idx = self.layer_default[1] self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds): def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = [] out_tokens = []
@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else: else:
z = outputs[1] z = outputs[1]
if outputs[2] is not None:
pooled_output = outputs[2].float()
else:
pooled_output = None pooled_output = None
if len(outputs) >= 3:
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
return z.float(), pooled_output return z.float(), pooled_output
@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name) self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
def clip_layer(self, layer_idx): def set_clip_options(self, options):
getattr(self, self.clip).clip_layer(layer_idx) getattr(self, self.clip).set_clip_options(options)
def reset_clip_layer(self): def reset_clip_options(self):
getattr(self, self.clip).reset_clip_layer() getattr(self, self.clip).reset_clip_options()
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name] token_weight_pairs = token_weight_pairs[self.clip_name]

View File

@ -40,13 +40,13 @@ class SDXLClipModel(torch.nn.Module):
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype) self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx): def set_clip_options(self, options):
self.clip_l.clip_layer(layer_idx) self.clip_l.set_clip_options(options)
self.clip_g.clip_layer(layer_idx) self.clip_g.set_clip_options(options)
def reset_clip_layer(self): def reset_clip_options(self):
self.clip_g.reset_clip_layer() self.clip_g.reset_clip_options()
self.clip_l.reset_clip_layer() self.clip_l.reset_clip_options()
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"] token_weight_pairs_g = token_weight_pairs["g"]

View File

@ -1003,7 +1003,7 @@ class GLIGENTextBoxApply:
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
c = [] c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
for t in conditioning_to: for t in conditioning_to:
n = [t[0], t[1].copy()] n = [t[0], t[1].copy()]
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]