Always return unprojected pooled output for gligen.

This commit is contained in:
comfyanonymous 2024-02-25 07:20:31 -05:00
parent 1cb3f6a83b
commit c2cb8e889b
6 changed files with 38 additions and 28 deletions

View File

@ -133,7 +133,7 @@ class CLIPTextModel(torch.nn.Module):
def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out)
return (x[0], x[1], out, x[2])
class CLIPVisionEmbeddings(torch.nn.Module):

View File

@ -201,9 +201,9 @@ def model_lora_keys_clip(model, key_map={}):
key_map[lora_key] = k
k = "clip_g.text_projection"
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
key_map["lora_prior_te_text_projection"] = k #cascade lora
key_map["lora_prior_te_text_projection"] = k #cascade lora?
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
# key_map["lora_te_text_projection"] = k

View File

@ -123,10 +123,13 @@ class CLIP:
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens, return_pooled=False):
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
else:
self.cond_stage_model.reset_clip_layer()
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)

View File

@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.enable_attention_masks = enable_attention_masks
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = True
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
self.set_clip_options({"layer": layer_idx})
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
def freeze(self):
self.transformer = self.transformer.eval()
@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
for param in self.parameters():
param.requires_grad = False
def clip_layer(self, layer_idx):
if abs(layer_idx) > self.num_layers:
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx
def reset_clip_layer(self):
self.layer = self.layer_default[0]
self.layer_idx = self.layer_default[1]
def reset_clip_options(self):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
z = outputs[1]
if outputs[2] is not None:
pooled_output = outputs[2].float()
else:
pooled_output = None
if len(outputs) >= 3:
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
return z.float(), pooled_output
@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
def clip_layer(self, layer_idx):
getattr(self, self.clip).clip_layer(layer_idx)
def set_clip_options(self, options):
getattr(self, self.clip).set_clip_options(options)
def reset_clip_layer(self):
getattr(self, self.clip).reset_clip_layer()
def reset_clip_options(self):
getattr(self, self.clip).reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]

View File

@ -40,13 +40,13 @@ class SDXLClipModel(torch.nn.Module):
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
self.clip_l.clip_layer(layer_idx)
self.clip_g.clip_layer(layer_idx)
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.clip_g.set_clip_options(options)
def reset_clip_layer(self):
self.clip_g.reset_clip_layer()
self.clip_l.reset_clip_layer()
def reset_clip_options(self):
self.clip_g.reset_clip_options()
self.clip_l.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"]

View File

@ -1003,7 +1003,7 @@ class GLIGENTextBoxApply:
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
for t in conditioning_to:
n = [t[0], t[1].copy()]
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]