Allow zeroing out of embeds with unused attention mask.

comfyanonymous 2024-07-05 23:48:17 -04:00
parent b4c2d03d47
commit ce649d61c0


@@ -169,7 +169,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = torch.LongTensor(tokens).to(device)
 
         attention_mask = None
-        if self.enable_attention_masks:
+        if self.enable_attention_masks or self.zero_out_masked:
             attention_mask = torch.zeros_like(tokens)
             end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
@@ -178,7 +178,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     if tokens[x, y] == end_token:
                         break
 
-        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        attention_mask_model = None
+        if self.enable_attention_masks:
+            attention_mask_model = attention_mask
+
+        outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
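
Taken together, the two hunks above decouple building the mask from using it: the mask is now constructed whenever either enable_attention_masks or zero_out_masked is set, but it is only handed to the transformer (as attention_mask_model) when enable_attention_masks is enabled. Below is a minimal standalone sketch of the mask construction; the mask[x, y] = 1 assignment falls between the two hunks and is reconstructed here, and the sample token IDs are made-up CLIP-style values:

import torch

def build_attention_mask(tokens: torch.Tensor, end_token: int) -> torch.Tensor:
    # 1 for every position up to and including the first end token, 0 afterwards.
    mask = torch.zeros_like(tokens)
    for x in range(mask.shape[0]):
        for y in range(mask.shape[1]):
            mask[x, y] = 1
            if tokens[x, y] == end_token:
                break
    return mask

tokens = torch.tensor([[49406, 320, 1125, 49407, 49407, 49407]])  # padded with the end token
print(build_attention_mask(tokens, end_token=49407))
# tensor([[1, 1, 1, 1, 0, 0]])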
@@ -186,7 +190,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         else:
             z = outputs[1].float()
 
-        if self.zero_out_masked and attention_mask is not None:
+        if self.zero_out_masked:
             z *= attention_mask.unsqueeze(-1).float()
 
         pooled_output = None
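
The final hunk drops the attention_mask is not None guard, which the first hunk makes redundant: the mask is now guaranteed to exist whenever zero_out_masked is set. The multiply then broadcasts the (batch, seq_len) mask across the hidden dimension, so masked positions of z become zero vectors while the rest pass through unchanged. A small sketch of that broadcast, with assumed shapes:

import torch

z = torch.randn(1, 6, 768)                 # (batch, seq_len, hidden) hidden states
mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # (batch, seq_len) attention mask

# unsqueeze(-1) gives the mask shape (batch, seq_len, 1), which broadcasts
# over the hidden dimension; every masked token's embedding is zeroed.
z = z * mask.unsqueeze(-1).float()
assert torch.all(z[:, 4:] == 0)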