mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support attention masking in CLIP implementation.
This commit is contained in:
parent
248d9125b0
commit
efb704c758
@ -100,8 +100,12 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||||
x = self.embeddings(input_tokens)
|
x = self.embeddings(input_tokens)
|
||||||
#TODO: attention_mask
|
mask = None
|
||||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
if attention_mask is not None:
|
||||||
|
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||||
|
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||||
|
|
||||||
|
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
|
||||||
x = self.final_layer_norm(x)
|
x = self.final_layer_norm(x)
|
||||||
if i is not None and final_layer_norm_intermediate:
|
if i is not None and final_layer_norm_intermediate:
|
||||||
i = self.final_layer_norm(i)
|
i = self.final_layer_norm(i)
|
||||||
|
Loading…
Reference in New Issue
Block a user