Add support for attention masking in Flux (#5942)

* fix attention OOM in xformers

* allow passing attention mask in flux attention

* allow an attn_mask in flux

* attn masks can be done using replace patches instead of a separate dict

* fix return types

* fix return order

* enumerate

* patch the right keys

* arg names

* fix a silly bug

* fix xformers masks

* replace match with if, elif, else

* mask with image_ref_size

* remove unused import

* remove unused import 2

* fix pytorch/xformers attention

This corrects a weird inconsistency with skip_reshape.
It also allows masks of various shapes to be passed, which will be
automtically expanded (in a memory-efficient way) to a size that is
compatible with xformers or pytorch sdpa respectively.

* fix mask shapes
This commit is contained in:
Raphael Walker 2024-12-17 00:21:17 +01:00 committed by GitHub
parent 0f954f34af
commit 61b50720d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 182 additions and 48 deletions

View File

@ -142,7 +142,7 @@ class DoubleStreamBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
) )
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec) img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec)
@ -163,7 +163,8 @@ class DoubleStreamBlock(nn.Module):
# run actual attention # run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2), attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2), torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2), pe=pe) torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@ -217,7 +218,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh") self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec) mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@ -226,7 +227,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v) q, k = self.norm(q, k, v)
# compute attention # compute attention
attn = attention(q, k, v, pe=pe) attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output x += mod.gate * output

View File

@ -1,14 +1,15 @@
import torch import torch
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe) q, k = apply_rope(q, k, pe)
heads = q.shape[1] heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True) x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x return x
@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -4,6 +4,8 @@ from dataclasses import dataclass
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from .layers import ( from .layers import (
DoubleStreamBlock, DoubleStreamBlock,
@ -14,9 +16,6 @@ from .layers import (
timestep_embedding, timestep_embedding,
) )
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass @dataclass
class FluxParams: class FluxParams:
in_channels: int in_channels: int
@ -100,6 +99,7 @@ class Flux(nn.Module):
guidance: Tensor = None, guidance: Tensor = None,
control = None, control = None,
transformer_options={}, transformer_options={},
attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
@ -124,14 +124,27 @@ class Flux(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"]) out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"] txt = out["txt"]
img = out["img"] img = out["img"]
else: else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe) img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
@ -146,13 +159,20 @@ class Flux(nn.Module):
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"]) out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap}) out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")
@ -181,5 +201,5 @@ class Flux(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@ -340,12 +340,9 @@ except:
pass pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape: b = q.shape[0]
b, _, _, dim_head = q.shape dim_head = q.shape[-1]
else: # check to make sure xformers isn't broken
b, _, dim_head = q.shape
dim_head //= heads
disabled_xformers = False disabled_xformers = False
if BROKEN_XFORMERS: if BROKEN_XFORMERS:
@ -360,32 +357,41 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape: if skip_reshape:
# b h k d -> b k h d
q, k, v = map( q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head), lambda t: t.permute(0, 2, 1, 3),
(q, k, v), (q, k, v),
) )
# actually do the reshaping
else: else:
dim_head //= heads
q, k, v = map( q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head), lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v), (q, k, v),
) )
if mask is not None: if mask is not None:
# add a singleton batch dimension
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a singleton heads dimension
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# pad to a multiple of 8
pad = 8 - mask.shape[-1] % 8 pad = 8 - mask.shape[-1] % 8
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device) # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
# in flux, this matrix ends up being over 1GB
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask mask_out[..., :mask.shape[-1]] = mask
# doesn't this remove the padding again??
mask = mask_out[..., :mask.shape[-1]] mask = mask_out[..., :mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if skip_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
else:
out = ( out = (
out.reshape(b, -1, heads * dim_head) out.reshape(b, -1, heads * dim_head)
) )
@ -410,15 +416,31 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v), (q, k, v),
) )
if SDP_BATCH_LIMIT >= q.shape[0]: if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
mask = mask.expand(b, heads, -1, -1)
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = ( out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head) out.transpose(1, 2).reshape(b, -1, heads * dim_head)
) )
else: else:
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device) out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, q.shape[0], SDP_BATCH_LIMIT): for i in range(0, b, SDP_BATCH_LIMIT):
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=None if mask is None else mask[i : i + SDP_BATCH_LIMIT],
dropout_p=0.0, is_causal=False
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out return out

View File

@ -686,6 +686,7 @@ class StableAudio1(BaseModel):
sd["{}{}".format(k, l)] = s[l] sd["{}{}".format(k, l)] = s[l]
return sd return sd
class HunyuanDiT(BaseModel): class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@ -766,6 +767,16 @@ class Flux(BaseModel):
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask, since now we
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)])) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out return out

View File

@ -26,6 +26,8 @@ import numpy as np
from PIL import Image from PIL import Image
import logging import logging
import itertools import itertools
from torch.nn.functional import interpolate
from einops import rearrange
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
@ -873,5 +875,46 @@ def reshape_mask(input_mask, output_shape):
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode) mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]: if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]] mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0]) mask = repeat_to_batch_size(mask, output_shape[0])
return mask return mask
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
hi, wi = img_size_in
ho, wo = img_size_out
# if it's already the correct size, no need to do anything
if (hi, wi) == (ho, wo):
return mask
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim != 3:
raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
txt_tokens = mask.shape[1] - (hi * wi)
# quadrants of the mask
txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
txt_to_img = mask[:, :txt_tokens, txt_tokens:]
img_to_img = mask[:, txt_tokens:, txt_tokens:]
img_to_txt = mask[:, txt_tokens:, :txt_tokens]
# convert to 1d x 2d, interpolate, then back to 1d x 1d
txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)")
# this one is hard because we have to do it twice
# convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
# convert to 2d x 1d, interpolate, then back to 1d x 1d
img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
# reassemble the mask from blocks
out = torch.cat([
torch.cat([txt_to_txt, txt_to_img], dim=2),
torch.cat([img_to_txt, img_to_img], dim=2)],
dim=1
)
return out

View File

@ -1008,23 +1008,58 @@ class StyleModelApply:
"style_model": ("STYLE_MODEL", ), "style_model": ("STYLE_MODEL", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ), "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"strength_type": (["multiply"], ), "strength_type": (["multiply", "attn_bias"], ),
}} }}
RETURN_TYPES = ("CONDITIONING",) RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_stylemodel" FUNCTION = "apply_stylemodel"
CATEGORY = "conditioning/style_model" CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type): def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0) cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
if strength_type == "multiply": if strength_type == "multiply":
cond *= strength cond *= strength
c = [] n = cond.shape[1]
c_out = []
for t in conditioning: for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()] (txt, keys) = t
c.append(n) keys = keys.copy()
return (c, ) if strength_type == "attn_bias" and strength != 1.0:
# math.log raises an error if the argument is zero
# torch.log returns -inf, which is what we want
attn_bias = torch.log(torch.Tensor([strength]))
# get the size of the mask image
mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
n_ref = mask_ref_size[0] * mask_ref_size[1]
n_txt = txt.shape[1]
# grab the existing mask
mask = keys.get("attention_mask", None)
# create a default mask if it doesn't exist
if mask is None:
mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
# convert the mask dtype, because it might be boolean
# we want it to be interpreted as a bias
if mask.dtype == torch.bool:
# log(True) = log(1) = 0
# log(False) = log(0) = -inf
mask = torch.log(mask.to(dtype=torch.float16))
# now we make the mask bigger to add space for our new tokens
new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
# copy over the old mask, in quandrants
new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
# now fill in the attention bias to our redux tokens
new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
keys["attention_mask"] = new_mask.to(txt.device)
keys["attention_mask_img_shape"] = mask_ref_size
c_out.append([torch.cat((txt, cond), dim=1), keys])
return (c_out,)
class unCLIPConditioning: class unCLIPConditioning:
@classmethod @classmethod