From efb704c758f916bdf3b8fcaa3c2ade69d03a27f8 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 7 Dec 2023 02:51:02 -0500
Subject: [PATCH] Support attention masking in CLIP implementation.

---
 comfy/clip_model.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index e6a7bfa6..c61353dc 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -100,8 +100,12 @@ class CLIPTextModel_(torch.nn.Module):
 
     def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
         x = self.embeddings(input_tokens)
-        #TODO: attention_mask
-        x, i = self.encoder(x, intermediate_output=intermediate_output)
+        mask = None
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+
+        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
         x = self.final_layer_norm(x)
         if i is not None and final_layer_norm_intermediate:
             i = self.final_layer_norm(i)
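
For reference, the mask construction this patch adds can be illustrated in isolation: a binary padding mask of shape `(batch, seq_len)` (1 for real tokens, 0 for padding) is inverted, broadcast to `(batch, 1, seq_len, seq_len)`, and converted into an additive attention bias where padded key positions become `-inf`, so softmax assigns them zero weight. The standalone PyTorch sketch below reproduces that transformation outside the model; the example tensors and printed shapes are illustrative and not part of the patch.

```python
import torch

# Illustrative padding mask: batch of 2 sequences, length 4; 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
dtype = torch.float32  # stands in for x.dtype in the patch

# Invert (padding becomes 1.0) and broadcast to (batch, 1, seq_len, seq_len),
# the shape attention expects for an additive bias over (query, key) positions.
mask = 1.0 - attention_mask.to(dtype).unsqueeze(1).unsqueeze(1).expand(
    attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])

# Replace the 1.0 entries (padded key positions) with -inf so softmax zeroes them out.
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

print(mask.shape)   # torch.Size([2, 1, 4, 4])
print(mask[1, 0])   # key columns 2 and 3 are -inf for every query position
```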