mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Always return unprojected pooled output for GLIGEN.
This commit is contained in:
parent
1cb3f6a83b
commit
c2cb8e889b
@ -133,7 +133,7 @@ class CLIPTextModel(torch.nn.Module):
|
|||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
x = self.text_model(*args, **kwargs)
|
x = self.text_model(*args, **kwargs)
|
||||||
out = self.text_projection(x[2])
|
out = self.text_projection(x[2])
|
||||||
return (x[0], x[1], out)
|
return (x[0], x[1], out, x[2])
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||||
|
@ -201,9 +201,9 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
|
|
||||||
|
|
||||||
k = "clip_g.text_projection"
|
k = "clip_g.transformer.text_projection.weight"
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
key_map["lora_prior_te_text_projection"] = k #cascade lora
|
key_map["lora_prior_te_text_projection"] = k #cascade lora?
|
||||||
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
# key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
|
||||||
# key_map["lora_te_text_projection"] = k
|
# key_map["lora_te_text_projection"] = k
|
||||||
|
|
||||||
|
@ -123,10 +123,13 @@ class CLIP:
|
|||||||
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
|
||||||
|
|
||||||
def encode_from_tokens(self, tokens, return_pooled=False):
|
def encode_from_tokens(self, tokens, return_pooled=False):
|
||||||
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
|
||||||
if self.layer_idx is not None:
|
if self.layer_idx is not None:
|
||||||
self.cond_stage_model.clip_layer(self.layer_idx)
|
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||||
else:
|
|
||||||
self.cond_stage_model.reset_clip_layer()
|
if return_pooled == "unprojected":
|
||||||
|
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||||
|
|
||||||
self.load_model()
|
self.load_model()
|
||||||
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
|
@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
self.enable_attention_masks = enable_attention_masks
|
self.enable_attention_masks = enable_attention_masks
|
||||||
|
|
||||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||||
|
self.return_projected_pooled = True
|
||||||
|
|
||||||
if layer == "hidden":
|
if layer == "hidden":
|
||||||
assert layer_idx is not None
|
assert layer_idx is not None
|
||||||
assert abs(layer_idx) < self.num_layers
|
assert abs(layer_idx) < self.num_layers
|
||||||
self.clip_layer(layer_idx)
|
self.set_clip_options({"layer": layer_idx})
|
||||||
self.layer_default = (self.layer, self.layer_idx)
|
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
|
||||||
|
|
||||||
def freeze(self):
|
def freeze(self):
|
||||||
self.transformer = self.transformer.eval()
|
self.transformer = self.transformer.eval()
|
||||||
@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
for param in self.parameters():
|
for param in self.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def set_clip_options(self, options):
|
||||||
if abs(layer_idx) > self.num_layers:
|
layer_idx = options.get("layer", self.layer_idx)
|
||||||
|
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||||
|
if layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||||
self.layer = "last"
|
self.layer = "last"
|
||||||
else:
|
else:
|
||||||
self.layer = "hidden"
|
self.layer = "hidden"
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
def reset_clip_layer(self):
|
def reset_clip_options(self):
|
||||||
self.layer = self.layer_default[0]
|
self.layer = self.options_default[0]
|
||||||
self.layer_idx = self.layer_default[1]
|
self.layer_idx = self.options_default[1]
|
||||||
|
self.return_projected_pooled = self.options_default[2]
|
||||||
|
|
||||||
def set_up_textual_embeddings(self, tokens, current_embeds):
|
def set_up_textual_embeddings(self, tokens, current_embeds):
|
||||||
out_tokens = []
|
out_tokens = []
|
||||||
@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
else:
|
else:
|
||||||
z = outputs[1]
|
z = outputs[1]
|
||||||
|
|
||||||
if outputs[2] is not None:
|
|
||||||
pooled_output = outputs[2].float()
|
|
||||||
else:
|
|
||||||
pooled_output = None
|
pooled_output = None
|
||||||
|
if len(outputs) >= 3:
|
||||||
|
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
|
||||||
|
pooled_output = outputs[3].float()
|
||||||
|
elif outputs[2] is not None:
|
||||||
|
pooled_output = outputs[2].float()
|
||||||
|
|
||||||
return z.float(), pooled_output
|
return z.float(), pooled_output
|
||||||
|
|
||||||
@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
self.clip = "clip_{}".format(self.clip_name)
|
self.clip = "clip_{}".format(self.clip_name)
|
||||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def set_clip_options(self, options):
|
||||||
getattr(self, self.clip).clip_layer(layer_idx)
|
getattr(self, self.clip).set_clip_options(options)
|
||||||
|
|
||||||
def reset_clip_layer(self):
|
def reset_clip_options(self):
|
||||||
getattr(self, self.clip).reset_clip_layer()
|
getattr(self, self.clip).reset_clip_options()
|
||||||
|
|
||||||
def encode_token_weights(self, token_weight_pairs):
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||||
|
@ -40,13 +40,13 @@ class SDXLClipModel(torch.nn.Module):
|
|||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def set_clip_options(self, options):
|
||||||
self.clip_l.clip_layer(layer_idx)
|
self.clip_l.set_clip_options(options)
|
||||||
self.clip_g.clip_layer(layer_idx)
|
self.clip_g.set_clip_options(options)
|
||||||
|
|
||||||
def reset_clip_layer(self):
|
def reset_clip_options(self):
|
||||||
self.clip_g.reset_clip_layer()
|
self.clip_g.reset_clip_options()
|
||||||
self.clip_l.reset_clip_layer()
|
self.clip_l.reset_clip_options()
|
||||||
|
|
||||||
def encode_token_weights(self, token_weight_pairs):
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
token_weight_pairs_g = token_weight_pairs["g"]
|
token_weight_pairs_g = token_weight_pairs["g"]
|
||||||
|
2
nodes.py
2
nodes.py
@ -1003,7 +1003,7 @@ class GLIGENTextBoxApply:
|
|||||||
|
|
||||||
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
|
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
|
||||||
c = []
|
c = []
|
||||||
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
|
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
|
||||||
for t in conditioning_to:
|
for t in conditioning_to:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
|
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
|
||||||
|
Loading…
Reference in New Issue
Block a user