From 61b50720d074e19217b79f167ecdd719325dd06a Mon Sep 17 00:00:00 2001
From: Raphael Walker
Date: Tue, 17 Dec 2024 00:21:17 +0100
Subject: [PATCH] Add support for attention masking in Flux (#5942)

* fix attention OOM in xformers

* allow passing attention mask in flux attention

* allow an attn_mask in flux

* attn masks can be done using replace patches instead of a separate dict

* fix return types

* fix return order

* enumerate

* patch the right keys

* arg names

* fix a silly bug

* fix xformers masks

* replace match with if, elif, else

* mask with image_ref_size

* remove unused import

* remove unused import 2

* fix pytorch/xformers attention

This corrects an inconsistency in how skip_reshape was handled. It also
allows masks of various shapes to be passed, which will be automatically
expanded (in a memory-efficient way) to a size that is compatible with
xformers or PyTorch SDPA respectively.

* fix mask shapes
---
 comfy/ldm/flux/layers.py       |  9 +++--
 comfy/ldm/flux/math.py         |  6 ++-
 comfy/ldm/flux/model.py        | 42 ++++++++++++++------
 comfy/ldm/modules/attention.py | 70 ++++++++++++++++++++++------------
 comfy/model_base.py            | 11 ++++++
 comfy/utils.py                 | 45 +++++++++++++++++++++-
 nodes.py                       | 47 ++++++++++++++++++++---
 7 files changed, 182 insertions(+), 48 deletions(-)

diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index dabab3e3..90ddab4f 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -142,7 +142,7 @@ class DoubleStreamBlock(nn.Module):
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
@@ -163,7 +163,8 @@ class DoubleStreamBlock(nn.Module):
         # run actual attention
         attn = attention(torch.cat((txt_q, img_q), dim=2),
                          torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2), pe=pe)
+                         torch.cat((txt_v, img_v), dim=2),
+                         pe=pe, mask=attn_mask)
 
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
@@ -217,7 +218,7 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
 
-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
         mod, _ = self.modulation(vec)
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -226,7 +227,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe)
+        attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += mod.gate * output
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 136ce2aa..b6549585 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -1,14 +1,15 @@
 import torch
 from einops import rearrange
 from torch import Tensor
+
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management
 
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q, k = apply_rope(q, k, pe)
 
     heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True)
+    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
     return x
 
 
@@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 97ad8ffe..02be0684 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -4,6 +4,8 @@ from dataclasses import dataclass
 
 import torch
 from torch import Tensor, nn
+from einops import rearrange, repeat
+import comfy.ldm.common_dit
 
 from .layers import (
     DoubleStreamBlock,
@@ -14,9 +16,6 @@ from .layers import (
     timestep_embedding,
 )
 
-from einops import rearrange, repeat
-import comfy.ldm.common_dit
-
 @dataclass
 class FluxParams:
     in_channels: int
@@ -98,8 +97,9 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
-        control=None,
+        control = None,
         transformer_options={},
+        attn_mask: Tensor = None,
     ) -> Tensor:
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -124,14 +124,27 @@ class Flux(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
+                    out["img"], out["txt"] = block(img=args["img"],
+                                                   txt=args["txt"],
+                                                   vec=args["vec"],
+                                                   pe=args["pe"],
+                                                   attn_mask=args.get("attn_mask"))
                     return out
 
-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img,
+                                                           "txt": txt,
+                                                           "vec": vec,
+                                                           "pe": pe,
+                                                           "attn_mask": attn_mask},
+                                                          {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
             else:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+                img, txt = block(img=img,
+                                 txt=txt,
+                                 vec=vec,
+                                 pe=pe,
+                                 attn_mask=attn_mask)
 
             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -146,13 +159,20 @@ class Flux(nn.Module):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
+                    out["img"] = block(args["img"],
+                                       vec=args["vec"],
+                                       pe=args["pe"],
+                                       attn_mask=args.get("attn_mask"))
                     return out
 
-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img,
+                                                           "vec": vec,
+                                                           "pe": pe,
+                                                           "attn_mask": attn_mask},
+                                                          {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
 
             if control is not None: # Controlnet
                 control_o = control.get("output")
@@ -181,5 +201,5 @@ class Flux(nn.Module):
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
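For orientation, the sketch below (not part of the patch, all sizes made up) shows what the attn_mask threaded through the Flux blocks above represents: a query-by-key mask over the concatenated [txt | img] token sequence, which PyTorch's scaled_dot_product_attention can broadcast over batch and heads.

# Illustrative only: a boolean mask over a [txt | img] token sequence.
import torch
import torch.nn.functional as F

batch, heads, dim_head = 1, 4, 64
n_txt, n_img = 8, 16              # hypothetical token counts
n_tok = n_txt + n_img

q = torch.randn(batch, heads, n_tok, dim_head)
k = torch.randn(batch, heads, n_tok, dim_head)
v = torch.randn(batch, heads, n_tok, dim_head)

# True = "may attend". Here the last 4 image tokens are hidden from the text tokens.
mask = torch.ones(n_tok, n_tok, dtype=torch.bool)
mask[:n_txt, n_txt + 12:] = False

# sdpa broadcasts a (Nq, Nk) boolean mask over batch and heads
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 4, 24, 64])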
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index f9a571e0..54b42a4b 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -340,12 +340,9 @@ except:
     pass
 
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    if skip_reshape:
-        b, _, _, dim_head = q.shape
-    else:
-        b, _, dim_head = q.shape
-        dim_head //= heads
-
+    b = q.shape[0]
+    dim_head = q.shape[-1]
+
     # check to make sure xformers isn't broken
     disabled_xformers = False
 
     if BROKEN_XFORMERS:
@@ -360,35 +357,44 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
 
     if skip_reshape:
-        q, k, v = map(
-            lambda t: t.reshape(b * heads, -1, dim_head),
+        # b h k d -> b k h d
+        q, k, v = map(
+            lambda t: t.permute(0, 2, 1, 3),
             (q, k, v),
         )
+    # actually do the reshaping
     else:
+        dim_head //= heads
         q, k, v = map(
             lambda t: t.reshape(b, -1, heads, dim_head),
             (q, k, v),
         )
 
     if mask is not None:
+        # add a singleton batch dimension
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a singleton heads dimension
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+        # pad to a multiple of 8
         pad = 8 - mask.shape[-1] % 8
-        mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        # the xformers docs say a mask of shape (1, Nq, Nk) is allowed, but with
+        # separate heads the shape has to be (B, H, Nq, Nk); in flux, a fully
+        # materialized mask of that size ends up being over 1GB, so we allocate it
+        # with the same batch/head size as the input mask (potentially singleton)
+        mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        mask_out[..., :mask.shape[-1]] = mask
+        # slicing narrows the view back to Nk; the storage stays padded to a multiple of 8, which xformers requires
         mask = mask_out[..., :mask.shape[-1]]
+        mask = mask.expand(b, heads, -1, -1)
 
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
-    if skip_reshape:
-        out = (
-            out.unsqueeze(0)
-            .reshape(b, heads, -1, dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b, -1, heads * dim_head)
-        )
-    else:
-        out = (
-            out.reshape(b, -1, heads * dim_head)
-        )
+    out = (
+        out.reshape(b, -1, heads * dim_head)
+    )
 
     return out
 
@@ -410,15 +416,31 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             (q, k, v),
         )
 
-    if SDP_BATCH_LIMIT >= q.shape[0]:
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+        mask = mask.expand(b, heads, -1, -1)
+
+
+    if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         out = (
             out.transpose(1, 2).reshape(b, -1, heads * dim_head)
         )
     else:
-        out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
-        for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
-            out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
+        out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
+        for i in range(0, b, SDP_BATCH_LIMIT):
+            out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
+                q[i : i + SDP_BATCH_LIMIT],
+                k[i : i + SDP_BATCH_LIMIT],
+                v[i : i + SDP_BATCH_LIMIT],
+                attn_mask=None if mask is None else mask[i : i + SDP_BATCH_LIMIT],
+                dropout_p=0.0, is_causal=False
+            ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
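The shape handling above boils down to promoting whatever mask is passed in to (B, H, Nq, Nk) without materializing B*H copies. A standalone sketch of that idea (names and sizes are illustrative, not taken from the patch):

# Illustrative only: expand() builds a broadcasted view, not a copy.
import torch

def normalize_mask(mask: torch.Tensor, b: int, heads: int) -> torch.Tensor:
    if mask.ndim == 2:          # (Nq, Nk)
        mask = mask.unsqueeze(0)
    if mask.ndim == 3:          # (B or 1, Nq, Nk)
        mask = mask.unsqueeze(1)
    return mask.expand(b, heads, -1, -1)   # strides only, no new memory

mask = torch.zeros(1024, 1024)
expanded = normalize_mask(mask, b=2, heads=24)
print(expanded.shape)                         # torch.Size([2, 24, 1024, 1024])
print(expanded.data_ptr() == mask.data_ptr()) # True: same underlying storage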
diff --git a/comfy/model_base.py b/comfy/model_base.py
index c64ab646..eef5d1dc 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -686,6 +686,7 @@ class StableAudio1(BaseModel):
                 sd["{}{}".format(k, l)] = s[l]
         return sd
 
+
 class HunyuanDiT(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@@ -766,6 +767,16 @@ class Flux(BaseModel):
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        # upscale the attention mask, since now we know the latent resolution
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            shape = kwargs["noise"].shape
+            mask_ref_size = kwargs["attention_mask_img_shape"]
+            # the model will pad the latent to the patch size and then divide,
+            # so this is essentially a division that rounds up
+            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
         return out
 
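The token-grid computation above is a ceiling division of the latent size by the model's patch size. A small worked example with hypothetical numbers:

# Illustrative arithmetic only; numbers are made up.
import math

latent_h, latent_w = 123, 90          # latent spatial size, e.g. kwargs["noise"].shape[2:]
patch_size = 2                        # Flux packs 2x2 latent patches into one token

h_tok = math.ceil(latent_h / patch_size)   # 62
w_tok = math.ceil(latent_w / patch_size)   # 45
n_img_tokens = h_tok * w_tok               # 2790: the image-token grid the mask is upscaled to
print(h_tok, w_tok, n_img_tokens)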
diff --git a/comfy/utils.py b/comfy/utils.py
index 3add621f..34659305 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -26,6 +26,8 @@ import numpy as np
 from PIL import Image
 import logging
 import itertools
+from torch.nn.functional import interpolate
+from einops import rearrange
 
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
@@ -873,5 +875,46 @@ def reshape_mask(input_mask, output_shape):
         mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
         if mask.shape[1] < output_shape[1]:
             mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
-    mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
+    mask = repeat_to_batch_size(mask, output_shape[0])
     return mask
+
+def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
+    hi, wi = img_size_in
+    ho, wo = img_size_out
+    # if it's already the correct size, no need to do anything
+    if (hi, wi) == (ho, wo):
+        return mask
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(0)
+    if mask.ndim != 3:
+        raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
+    txt_tokens = mask.shape[1] - (hi * wi)
+    # quadrants of the mask
+    txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
+    txt_to_img = mask[:, :txt_tokens, txt_tokens:]
+    img_to_img = mask[:, txt_tokens:, txt_tokens:]
+    img_to_txt = mask[:, txt_tokens:, :txt_tokens]
+
+    # convert to 1d x 2d, interpolate, then back to 1d x 1d
+    txt_to_img = rearrange(txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
+    txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
+    txt_to_img = rearrange(txt_to_img, "b t h w -> b t (h w)")
+    # this one is hard because we have to do it twice
+    # convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
+    img_to_img = rearrange(img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
+    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
+    img_to_img = rearrange(img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
+    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
+    img_to_img = rearrange(img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
+    # convert to 2d x 1d, interpolate, then back to 1d x 1d
+    img_to_txt = rearrange(img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
+    img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
+    img_to_txt = rearrange(img_to_txt, "b t h w -> b (h w) t")
+
+    # reassemble the mask from its four blocks
+    out = torch.cat([
+        torch.cat([txt_to_txt, txt_to_img], dim=2),
+        torch.cat([img_to_txt, img_to_img], dim=2)],
+        dim=1
+    )
+    return out
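A condensed, standalone illustration of the quadrant-resize idea used by upscale_dit_mask above, showing only the txt-to-img block, with made-up sizes:

# Illustrative only: resize the txt->img quadrant of a [b, q, k] mask.
import torch
from torch.nn.functional import interpolate
from einops import rearrange

hi, wi = 4, 4          # the mask was authored against a 4x4 image-token grid
ho, wo = 8, 6          # actual token grid of the latent being sampled
n_txt = 3

mask = torch.rand(1, n_txt + hi * wi, n_txt + hi * wi)   # toy [b, q, k] mask

txt_to_img = mask[:, :n_txt, n_txt:]                     # (1, n_txt, hi*wi)
txt_to_img = rearrange(txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
txt_to_img = interpolate(txt_to_img, size=(ho, wo), mode="bilinear")
txt_to_img = rearrange(txt_to_img, "b t h w -> b t (h w)")
print(txt_to_img.shape)                                  # torch.Size([1, 3, 48])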
diff --git a/nodes.py b/nodes.py
index aab31bfc..8f1f1499 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1008,23 +1008,58 @@ class StyleModelApply:
                              "style_model": ("STYLE_MODEL", ),
                              "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
-                             "strength_type": (["multiply"], ),
+                             "strength_type": (["multiply", "attn_bias"], ),
                              }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_stylemodel"
 
     CATEGORY = "conditioning/style_model"
 
-    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
+    def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
         cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
         if strength_type == "multiply":
             cond *= strength
 
-        c = []
+        n = cond.shape[1]
+        c_out = []
         for t in conditioning:
-            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
-            c.append(n)
-        return (c, )
+            (txt, keys) = t
+            keys = keys.copy()
+            if strength_type == "attn_bias" and strength != 1.0:
+                # math.log raises an error if the argument is zero,
+                # torch.log returns -inf, which is what we want
+                attn_bias = torch.log(torch.Tensor([strength]))
+                # get the size of the mask image
+                mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
+                n_ref = mask_ref_size[0] * mask_ref_size[1]
+                n_txt = txt.shape[1]
+                # grab the existing mask
+                mask = keys.get("attention_mask", None)
+                # create a default mask if it doesn't exist
+                if mask is None:
+                    mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
+                # convert the mask dtype, because it might be boolean;
+                # we want it to be interpreted as a bias
+                if mask.dtype == torch.bool:
+                    # log(True) = log(1) = 0
+                    # log(False) = log(0) = -inf
+                    mask = torch.log(mask.to(dtype=torch.float16))
+                # now we make the mask bigger to add space for our new tokens
+                new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
+                # copy over the old mask, in quadrants
+                new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
+                new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
+                new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
+                new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
+                # now fill in the attention bias for our redux tokens
+                new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
+                new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
+                keys["attention_mask"] = new_mask.to(txt.device)
+                keys["attention_mask_img_shape"] = mask_ref_size
+
+            c_out.append([torch.cat((txt, cond), dim=1), keys])
+
+        return (c_out,)
 
 class unCLIPConditioning:
     @classmethod
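A standalone sketch of the mask-enlargement step in apply_stylemodel above, with hypothetical token counts: the n style tokens are spliced in after the text tokens, the old quadrants are copied over, and the new rows/columns get a log(strength) bias.

# Illustrative only: splice n style tokens into an additive attention bias.
import torch

n_txt, n, n_ref = 5, 3, 4                     # hypothetical token counts
strength = 0.5
attn_bias = torch.log(torch.tensor([strength]))

old = torch.zeros(1, n_txt + n_ref, n_txt + n_ref)
new = torch.zeros(1, n_txt + n + n_ref, n_txt + n + n_ref)

new[:, :n_txt, :n_txt] = old[:, :n_txt, :n_txt]          # txt -> txt
new[:, :n_txt, n_txt + n:] = old[:, :n_txt, n_txt:]      # txt -> ref
new[:, n_txt + n:, :n_txt] = old[:, n_txt:, :n_txt]      # ref -> txt
new[:, n_txt + n:, n_txt + n:] = old[:, n_txt:, n_txt:]  # ref -> ref
new[:, :n_txt, n_txt:n_txt + n] = attn_bias              # txt attends to style tokens with bias
new[:, n_txt + n:, n_txt:n_txt + n] = attn_bias          # ref attends to style tokens with bias
print(new[0, 0])                                         # first query row of the biased mask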