mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Always return unprojected pooled output for gligen.
This commit is contained in:
parent
1cb3f6a83b
commit
c2cb8e889b
@ -133,7 +133,7 @@ class CLIPTextModel(torch.nn.Module):
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.text_model(*args, **kwargs)
|
||||
out = self.text_projection(x[2])
|
||||
return (x[0], x[1], out)
|
||||
return (x[0], x[1], out, x[2])
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
|
@ -201,9 +201,9 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
key_map[lora_key] = k
|
||||
|
||||
|
||||
k = "clip_g.text_projection"
|
||||
k = "clip_g.transformer.text_projection.weight"
|
||||
if k in sdk:
|
||||
key_map["lora_prior_te_text_projection"] = k #cascade lora
|
||||
key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
||||
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
||||
# key_map["lora_te_text_projection"] = k
|
||||
|
||||
|
@ -123,10 +123,13 @@ class CLIP:
|
||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||
|
||||
def encode_from_tokens(self, tokens, return_pooled=False):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.clip_layer(self.layer_idx)
|
||||
else:
|
||||
self.cond_stage_model.reset_clip_layer()
|
||||
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||
|
||||
if return_pooled == "unprojected":
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||
|
@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.enable_attention_masks = enable_attention_masks
|
||||
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = True
|
||||
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.clip_layer(layer_idx)
|
||||
self.layer_default = (self.layer, self.layer_idx)
|
||||
self.set_clip_options({"layer": layer_idx})
|
||||
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
if abs(layer_idx) > self.num_layers:
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.layer = self.layer_default[0]
|
||||
self.layer_idx = self.layer_default[1]
|
||||
def reset_clip_options(self):
|
||||
self.layer = self.options_default[0]
|
||||
self.layer_idx = self.options_default[1]
|
||||
self.return_projected_pooled = self.options_default[2]
|
||||
|
||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||
out_tokens = []
|
||||
@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
z = outputs[1]
|
||||
|
||||
if outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
else:
|
||||
pooled_output = None
|
||||
if len(outputs) >= 3:
|
||||
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
|
||||
pooled_output = outputs[3].float()
|
||||
elif outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
|
||||
return z.float(), pooled_output
|
||||
|
||||
@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
getattr(self, self.clip).clip_layer(layer_idx)
|
||||
def set_clip_options(self, options):
|
||||
getattr(self, self.clip).set_clip_options(options)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
getattr(self, self.clip).reset_clip_layer()
|
||||
def reset_clip_options(self):
|
||||
getattr(self, self.clip).reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
|
@ -40,13 +40,13 @@ class SDXLClipModel(torch.nn.Module):
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
self.clip_l.clip_layer(layer_idx)
|
||||
self.clip_g.clip_layer(layer_idx)
|
||||
def set_clip_options(self, options):
|
||||
self.clip_l.set_clip_options(options)
|
||||
self.clip_g.set_clip_options(options)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.clip_g.reset_clip_layer()
|
||||
self.clip_l.reset_clip_layer()
|
||||
def reset_clip_options(self):
|
||||
self.clip_g.reset_clip_options()
|
||||
self.clip_l.reset_clip_options()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
|
2
nodes.py
2
nodes.py
@ -1003,7 +1003,7 @@ class GLIGENTextBoxApply:
|
||||
|
||||
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
|
||||
c = []
|
||||
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
|
||||
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
|
||||
for t in conditioning_to:
|
||||
n = [t[0], t[1].copy()]
|
||||
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
|
||||
|
Loading…
Reference in New Issue
Block a user