Use maximum negative value instead of -inf for masks in text encoders.

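This is probably more correct: when every position in a row is masked with -inf, softmax evaluates to NaN for that row, whereas filling with the most negative finite value of the dtype (-torch.finfo(dtype).max) still suppresses the masked positions but keeps the softmax output finite.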
This commit is contained in:
comfyanonymous 2025-02-02 09:45:07 -05:00
parent 0a0df5f136
commit 44e19a28d3
3 changed files with 4 additions and 4 deletions

View File

@ -102,9 +102,9 @@ class CLIPTextModel_(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(-torch.finfo(x.dtype).max).triu_(1)
if mask is not None:
mask += causal_mask
else:

View File

@ -118,7 +118,7 @@ class BertModel_(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
x, i = self.encoder(x, mask, intermediate_output)
return x, i

View File

@ -203,7 +203,7 @@ class T5Stack(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)