From e44fa5667fd0a5469b1b5efa08187282779f4d44 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 Jul 2024 19:31:22 -0400 Subject: [PATCH] Support returning text encoder attention masks. --- comfy/sd1_clip.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0fe1f1d1..4da2b46f 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -38,7 +38,9 @@ class ClipTokenWeightEncoder: if has_weights or sections == 0: to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) - out, pooled = self.encode(to_encode) + o = self.encode(to_encode) + out, pooled = o[:2] + if pooled is not None: first_pooled = pooled[0:1].to(model_management.intermediate_device()) else: @@ -57,8 +59,11 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - return out[-1:].to(model_management.intermediate_device()), first_pooled - return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled + r = (out[-1:].to(model_management.intermediate_device()), first_pooled) + else: + r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) + r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:])) + return r class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" @@ -70,7 +75,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, - return_projected_pooled=True): # clip-vit-base-patch32 + return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -96,6 +101,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled + self.return_attention_masks = return_attention_masks if layer == "hidden": assert layer_idx is not None @@ -169,7 +175,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = torch.LongTensor(tokens).to(device) attention_mask = None - if self.enable_attention_masks or self.zero_out_masked: + if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks: attention_mask = torch.zeros_like(tokens) end_token = self.special_tokens.get("end", -1) for x in range(attention_mask.shape[0]): @@ -200,6 +206,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): elif outputs[2] is not None: pooled_output = outputs[2].float() + if self.return_attention_masks: + return z, pooled_output, attention_mask + return z, pooled_output def encode(self, tokens):