Support returning text encoder attention masks.

This commit is contained in:
comfyanonymous 2024-07-10 19:31:22 -04:00
parent 90389b3b8a
commit e44fa5667f

View File

@ -38,7 +38,9 @@ class ClipTokenWeightEncoder:
if has_weights or sections == 0: if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode) o = self.encode(to_encode)
out, pooled = o[:2]
if pooled is not None: if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device()) first_pooled = pooled[0:1].to(model_management.intermediate_device())
else: else:
@ -57,8 +59,11 @@ class ClipTokenWeightEncoder:
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
return out[-1:].to(model_management.intermediate_device()), first_pooled r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
@ -70,7 +75,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True): # clip-vit-base-patch32 return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
@ -96,6 +101,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
if layer == "hidden": if layer == "hidden":
assert layer_idx is not None assert layer_idx is not None
@ -169,7 +175,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = torch.LongTensor(tokens).to(device) tokens = torch.LongTensor(tokens).to(device)
attention_mask = None attention_mask = None
if self.enable_attention_masks or self.zero_out_masked: if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
attention_mask = torch.zeros_like(tokens) attention_mask = torch.zeros_like(tokens)
end_token = self.special_tokens.get("end", -1) end_token = self.special_tokens.get("end", -1)
for x in range(attention_mask.shape[0]): for x in range(attention_mask.shape[0]):
@ -200,6 +206,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
elif outputs[2] is not None: elif outputs[2] is not None:
pooled_output = outputs[2].float() pooled_output = outputs[2].float()
if self.return_attention_masks:
return z, pooled_output, attention_mask
return z, pooled_output return z, pooled_output
def encode(self, tokens): def encode(self, tokens):