Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-01-11 02:15:17 +00:00
Support returning text encoder attention masks.

commit e44fa5667f
parent 90389b3b8a
@@ -38,7 +38,9 @@ class ClipTokenWeightEncoder:
         if has_weights or sections == 0:
             to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
 
-        out, pooled = self.encode(to_encode)
+        o = self.encode(to_encode)
+        out, pooled = o[:2]
+
         if pooled is not None:
             first_pooled = pooled[0:1].to(model_management.intermediate_device())
         else:
@@ -57,8 +59,11 @@ class ClipTokenWeightEncoder:
             output.append(z)
 
         if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
+        else:
+            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
+        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+        return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
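Reading the new return path in this hunk: self.encode can now yield more than two values, and everything past (out, pooled) — here, the attention mask — is cropped to the first sections entries, flattened into a single row, and appended to the returned tuple. A minimal sketch of that transformation with made-up shapes (illustration only, not part of the commit):

import torch

sections = 2                 # number of prompt chunks that were encoded
mask = torch.ones(3, 77)     # hypothetical per-chunk mask: 3 chunks x 77 tokens
                             # (the extra chunk comes from the empty padding prompt)

extra = mask[:sections].flatten().unsqueeze(dim=0)
print(extra.shape)           # torch.Size([1, 154]): one flat mask row covering both chunks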
@@ -70,7 +75,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True): # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
 
@@ -96,6 +101,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
+        self.return_attention_masks = return_attention_masks
 
         if layer == "hidden":
             assert layer_idx is not None
@@ -169,7 +175,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = torch.LongTensor(tokens).to(device)
 
         attention_mask = None
-        if self.enable_attention_masks or self.zero_out_masked:
+        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
             attention_mask = torch.zeros_like(tokens)
             end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
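This hunk only shows the mask being initialized; the loop body that fills it sits outside the changed lines. Presumably it marks every token up to and including the end token as attended. A rough sketch of that idea (the loop body and example token ids are assumptions, not part of this diff):

import torch

tokens = torch.tensor([[49406, 320, 1125, 49407, 49407, 49407]])  # start, two words, end, padding
end_token = 49407
attention_mask = torch.zeros_like(tokens)
for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1       # attend to this token
        if tokens[x, y] == end_token:  # stop once the end token is reached
            break
print(attention_mask)                  # tensor([[1, 1, 1, 1, 0, 0]])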
@@ -200,6 +206,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
+        if self.return_attention_masks:
+            return z, pooled_output, attention_mask
+
         return z, pooled_output
 
     def encode(self, tokens):
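Taken together, the new flag threads from the constructor to both return sites: with return_attention_masks=True the forward pass returns (z, pooled_output, attention_mask), and encode_token_weights passes the flattened mask along as an extra element. A minimal usage sketch, assuming clip_model is an SDClipModel built with return_attention_masks=True and token_weight_pairs is the matching tokenizer output (names and shapes are illustrative, not prescribed by the commit):

out = clip_model.encode_token_weights(token_weight_pairs)
cond, pooled = out[:2]        # the two values callers received before this change
if len(out) > 2:
    attention_mask = out[2]   # shape [1, n_tokens]: 1 for real tokens, 0 for padding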