Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-04-20 03:13:30 +00:00
More flexibility with text encoder return values.

Text encoders can now return values to the CONDITIONING other than the cond and the pooled output.
parent e44fa5667f
commit 391c1046cf
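From a caller's side, the new flag looks like this; a minimal sketch based on the diff below, assuming a loaded `clip` object and a prompt string `text`:

tokens = clip.tokenize(text)

# Old style: a (cond, pooled) tuple.
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)

# New style: a dict that can carry extra encoder outputs
# (e.g. "attention_mask") alongside cond and pooled_output.
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
conditioning = [[cond, output]]  # output keeps "pooled_output" plus any extras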
12 comfy/sd.py
@@ -130,7 +130,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
-    def encode_from_tokens(self, tokens, return_pooled=False):
+    def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
         self.cond_stage_model.reset_clip_options()
 
         if self.layer_idx is not None:
@@ -140,7 +140,15 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
-        cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
+        o = self.cond_stage_model.encode_token_weights(tokens)
+        cond, pooled = o[:2]
+        if return_dict:
+            out = {"cond": cond, "pooled_output": pooled}
+            if len(o) > 2:
+                for k in o[2]:
+                    out[k] = o[2][k]
+            return out
+
         if return_pooled:
             return cond, pooled
         return cond
21 comfy/sd1_clip.py
@@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
             r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
         else:
             r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
-        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+
+        if len(o) > 2:
+            extra = {}
+            for k in o[2]:
+                v = o[2][k]
+                if k == "attention_mask":
+                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+                extra[k] = v
+
+            r = r + (extra,)
         return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
@@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
+        extra = {}
         if self.return_attention_masks:
-            return z, pooled_output, attention_mask
+            extra["attention_mask"] = attention_mask
+
+        if len(extra) > 0:
+            return z, pooled_output, extra
 
         return z, pooled_output
@@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
-        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
-        return out, pooled
+        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out
 
     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)
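For encoder implementations, the convention is now: encode_token_weights returns (cond, pooled) and may append a dict of extra values as a third element, which CLIP.encode_from_tokens(..., return_dict=True) merges into the conditioning. A hypothetical minimal encoder illustrating the shape of the contract (class name and dummy tensors are made up for illustration, not part of the commit):

import torch

class ToyEncoder:
    # Illustrative only: SD1-style dummy shapes, not real model output.
    def encode_token_weights(self, token_weight_pairs):
        cond = torch.zeros((1, 77, 768))    # dummy token embeddings
        pooled = torch.zeros((1, 768))      # dummy pooled output
        # Optional third element: extra values for the CONDITIONING dict.
        extra = {"attention_mask": torch.ones((1, 77), dtype=torch.long)}
        return cond, pooled, extra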
5 nodes.py
@@ -55,8 +55,9 @@ class CLIPTextEncode:
 
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
-        return ([[cond, {"pooled_output": pooled}]], )
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )
 
 class ConditioningCombine:
     @classmethod
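Net effect: any extra keys returned by the text encoder (such as "attention_mask") now flow through CLIPTextEncode into the conditioning dict alongside "pooled_output", e.g. a result of the hypothetical form [[cond, {"pooled_output": pooled, "attention_mask": mask}]], with no further per-node changes needed.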