mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Properly fix attention masks in CLIP with batches.
This commit is contained in:
parent
5b40e7a5ed
commit
3b9969c1c5
@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
x = self.embeddings(input_tokens)
|
x = self.embeddings(input_tokens)
|
||||||
mask = None
|
mask = None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], attention_mask.shape[-1], attention_mask.shape[-1])
|
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||||
|
|
||||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||||
|
@ -114,7 +114,8 @@ def attention_basic(q, k, v, heads, mask=None):
|
|||||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
sim.masked_fill_(~mask, max_neg_value)
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
else:
|
else:
|
||||||
sim += mask
|
mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(sim.shape)
|
||||||
|
sim.add_(mask)
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
# attention, what we cannot get enough of
|
||||||
sim = sim.softmax(dim=-1)
|
sim = sim.softmax(dim=-1)
|
||||||
@ -165,6 +166,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
|||||||
if query_chunk_size is None:
|
if query_chunk_size is None:
|
||||||
query_chunk_size = 512
|
query_chunk_size = 512
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||||
|
|
||||||
hidden_states = efficient_dot_product_attention(
|
hidden_states = efficient_dot_product_attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@ -223,6 +227,9 @@ def attention_split(q, k, v, heads, mask=None):
|
|||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.reshape(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]).expand(-1, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
|
||||||
|
|
||||||
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
|
||||||
first_op_done = False
|
first_op_done = False
|
||||||
cleared_cache = False
|
cleared_cache = False
|
||||||
|
Loading…
Reference in New Issue
Block a user