Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-11 02:15:17 +00:00)
Add support for attention masking in Flux (#5942)
* fix attention OOM in xformers
* allow passing attention mask in flux attention
* allow an attn_mask in flux
* attn masks can be done using replace patches instead of a separate dict
* fix return types
* fix return order
* enumerate
* patch the right keys
* arg names
* fix a silly bug
* fix xformers masks
* replace match with if, elif, else
* mask with image_ref_size
* remove unused import
* remove unused import 2
* fix pytorch/xformers attention

  This corrects a weird inconsistency with skip_reshape. It also allows masks of various shapes to be passed, which will be automatically expanded (in a memory-efficient way) to a size that is compatible with xformers or pytorch sdpa respectively.

* fix mask shapes
This commit is contained in:
parent 0f954f34af
commit 61b50720d0
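For orientation, here is a minimal standalone sketch of the conditioning-side hook this change wires up, assuming ComfyUI's usual conditioning format of [tensor, options-dict] pairs. The helper name add_flux_attention_mask is hypothetical; only the two option keys ("attention_mask", "attention_mask_img_shape") come from the diff below.

import torch

def add_flux_attention_mask(conditioning, mask, img_shape=(1, 1)):
    # hypothetical helper: attach an additive attention mask to each conditioning
    # entry; the mask is laid out as [txt_tokens + h*w, txt_tokens + h*w] and
    # img_shape records the h x w image-token grid it was built against, so it
    # can later be rescaled to the actual latent size (see upscale_dit_mask below)
    out = []
    for cond_tensor, opts in conditioning:
        opts = opts.copy()
        opts["attention_mask"] = mask
        opts["attention_mask_img_shape"] = img_shape
        out.append([cond_tensor, opts])
    return out

# toy usage: 256 text tokens plus a 2x2 reference image grid, all-zero bias (no masking)
toy_cond = [[torch.zeros(1, 256, 4096), {}]]
masked_cond = add_flux_attention_mask(toy_cond, torch.zeros(1, 256 + 4, 256 + 4), (2, 2))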
@@ -142,7 +142,7 @@ class DoubleStreamBlock(nn.Module):
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )

-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)

@@ -163,7 +163,8 @@ class DoubleStreamBlock(nn.Module):
         # run actual attention
         attn = attention(torch.cat((txt_q, img_q), dim=2),
                          torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2), pe=pe)
+                         torch.cat((txt_v, img_v), dim=2),
+                         pe=pe, mask=attn_mask)

         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

@@ -217,7 +218,7 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
         mod, _ = self.modulation(vec)
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -226,7 +227,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)

         # compute attention
-        attn = attention(q, k, v, pe=pe)
+        attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += mod.gate * output
@@ -1,14 +1,15 @@
 import torch
 from einops import rearrange
 from torch import Tensor

 from comfy.ldm.modules.attention import optimized_attention
 import comfy.model_management

-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
     q, k = apply_rope(q, k, pe)

     heads = q.shape[1]
-    x = optimized_attention(q, k, v, heads, skip_reshape=True)
+    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
     return x

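The helper above now simply forwards the mask. As a standalone illustration (plain torch, not ComfyUI's optimized_attention) of its effect once RoPE has been applied: q/k/v arrive as (batch, heads, tokens, dim_head) and an optional additive bias broadcastable to (B, H, Nq, Nk) is handed straight to scaled dot-product attention.

import torch
import torch.nn.functional as F

def masked_attention_sketch(q, k, v, mask=None):
    # q, k, v: (B, H, N, D); mask: additive bias broadcastable to (B, H, Nq, Nk)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    # fold heads back into the channel dimension, matching the skip_reshape=True output layout
    b, h, n, d = q.shape
    return out.transpose(1, 2).reshape(b, n, h * d)

q = k = v = torch.randn(1, 24, 8, 128)     # toy token count; flux uses ~4k tokens, 24 heads
bias = torch.zeros(1, 1, 8, 8)             # (1, 1, Nq, Nk) broadcasts over batch and heads
bias[..., :, 4:] = float("-inf")           # e.g. hide the last 4 key tokens from every query
print(masked_attention_sketch(q, k, v, bias).shape)   # torch.Size([1, 8, 3072])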
@@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
@@ -4,6 +4,8 @@ from dataclasses import dataclass

 import torch
 from torch import Tensor, nn
+from einops import rearrange, repeat
+import comfy.ldm.common_dit

 from .layers import (
     DoubleStreamBlock,
@@ -14,9 +16,6 @@ from .layers import (
     timestep_embedding,
 )

-from einops import rearrange, repeat
-import comfy.ldm.common_dit
-
 @dataclass
 class FluxParams:
     in_channels: int
@@ -98,8 +97,9 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
-        control=None,
+        control = None,
         transformer_options={},
+        attn_mask: Tensor = None,
     ) -> Tensor:
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -124,14 +124,27 @@ class Flux(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
+                    out["img"], out["txt"] = block(img=args["img"],
+                                                   txt=args["txt"],
+                                                   vec=args["vec"],
+                                                   pe=args["pe"],
+                                                   attn_mask=args.get("attn_mask"))
                     return out

-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img,
+                                                           "txt": txt,
+                                                           "vec": vec,
+                                                           "pe": pe,
+                                                           "attn_mask": attn_mask},
+                                                          {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
             else:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+                img, txt = block(img=img,
+                                 txt=txt,
+                                 vec=vec,
+                                 pe=pe,
+                                 attn_mask=attn_mask)

             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -146,13 +159,20 @@ class Flux(nn.Module):
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
+                    out["img"] = block(args["img"],
+                                       vec=args["vec"],
+                                       pe=args["pe"],
+                                       attn_mask=args.get("attn_mask"))
                     return out

-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img,
+                                                           "vec": vec,
+                                                           "pe": pe,
+                                                           "attn_mask": attn_mask},
+                                                          {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

             if control is not None: # Controlnet
                 control_o = control.get("output")
@@ -181,5 +201,5 @@ class Flux(nn.Module):
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
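Because the args dict handed to a "double_block"/"single_block" replace patch now carries "attn_mask", a patch that simply delegates to original_block keeps masking intact. A toy sketch of the calling convention follows; the stand-in block and its values are illustrative, the dict keys and two-argument patch signature come from the diff above.

def my_double_block_patch(args, extra):
    # args: {"img", "txt", "vec", "pe", "attn_mask"}; extra: {"original_block"}
    out = extra["original_block"](args)
    # ...inspect or modify out["img"] / out["txt"] here...
    return out

# toy stand-in for the wrapped block, just to exercise the calling convention
fake_block = lambda args: {"img": args["img"], "txt": args["txt"]}
out = my_double_block_patch(
    {"img": "IMG", "txt": "TXT", "vec": None, "pe": None, "attn_mask": None},
    {"original_block": fake_block},
)
assert out == {"img": "IMG", "txt": "TXT"}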
@@ -340,12 +340,9 @@ except:
     pass

 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    if skip_reshape:
-        b, _, _, dim_head = q.shape
-    else:
-        b, _, dim_head = q.shape
-        dim_head //= heads
+    b = q.shape[0]
+    dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
     disabled_xformers = False

     if BROKEN_XFORMERS:
@@ -360,32 +357,41 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)

     if skip_reshape:
+        # b h k d -> b k h d
         q, k, v = map(
-            lambda t: t.reshape(b * heads, -1, dim_head),
+            lambda t: t.permute(0, 2, 1, 3),
             (q, k, v),
         )
+    # actually do the reshaping
     else:
+        dim_head //= heads
         q, k, v = map(
             lambda t: t.reshape(b, -1, heads, dim_head),
             (q, k, v),
         )

     if mask is not None:
+        # add a singleton batch dimension
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a singleton heads dimension
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+        # pad to a multiple of 8
         pad = 8 - mask.shape[-1] % 8
-        mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
+        # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
+        # but when using separated heads, the shape has to be (B, H, Nq, Nk)
+        # in flux, this matrix ends up being over 1GB
+        # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
+        mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)

         mask_out[..., :mask.shape[-1]] = mask
+        # doesn't this remove the padding again??
         mask = mask_out[..., :mask.shape[-1]]
+        mask = mask.expand(b, heads, -1, -1)

     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

     if skip_reshape:
         out = (
             out.unsqueeze(0)
             .reshape(b, heads, -1, dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b, -1, heads * dim_head)
         )
     else:
         out = (
             out.reshape(b, -1, heads * dim_head)
         )
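The comments above are the crux of the memory fix: rather than materializing a full (B, H, Nq, Nk) bias, the singleton-shaped mask is expanded, which only creates a broadcast view. A standalone illustration with assumed, smaller-than-Flux sizes:

import torch

b, heads, nq, nk = 2, 24, 1024, 1024        # flux sequences are ~4k tokens, so a
mask = torch.zeros(nq, nk)                  # materialized (B, H, Nq, Nk) bias would be >1GB
if mask.ndim == 2:                          # add a singleton batch dimension
    mask = mask.unsqueeze(0)
if mask.ndim == 3:                          # add a singleton heads dimension
    mask = mask.unsqueeze(1)
expanded = mask.expand(b, heads, -1, -1)    # broadcast view, no copy

print(expanded.shape)                          # torch.Size([2, 24, 1024, 1024])
print(expanded.data_ptr() == mask.data_ptr())  # True: same underlying storage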
@@ -410,15 +416,31 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             (q, k, v),
         )

-    if SDP_BATCH_LIMIT >= q.shape[0]:
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+        mask = mask.expand(b, heads, -1, -1)
+
+    if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         out = (
             out.transpose(1, 2).reshape(b, -1, heads * dim_head)
         )
     else:
-        out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
-        for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
-            out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
+        out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
+        for i in range(0, b, SDP_BATCH_LIMIT):
+            out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
+                q[i : i + SDP_BATCH_LIMIT],
+                k[i : i + SDP_BATCH_LIMIT],
+                v[i : i + SDP_BATCH_LIMIT],
+                attn_mask=None if mask is None else mask[i : i + SDP_BATCH_LIMIT],
+                dropout_p=0.0, is_causal=False
+            ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
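For the chunked path above, slicing the expanded mask along the batch dimension per chunk is equivalent to one full call; a quick standalone check with assumed shapes:

import torch
import torch.nn.functional as F

b, h, n, d, limit = 6, 4, 32, 16, 2
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
mask = torch.randn(1, 1, n, n).expand(b, h, -1, -1)   # broadcast view, as in the xformers path

full = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
chunks = [
    F.scaled_dot_product_attention(q[i:i + limit], k[i:i + limit], v[i:i + limit],
                                   attn_mask=mask[i:i + limit])
    for i in range(0, b, limit)
]
print(torch.allclose(full, torch.cat(chunks), atol=1e-6))   # True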
@@ -686,6 +686,7 @@ class StableAudio1(BaseModel):
             sd["{}{}".format(k, l)] = s[l]
         return sd

+
 class HunyuanDiT(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@@ -766,6 +767,16 @@ class Flux(BaseModel):
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        # upscale the attention mask, since now we
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            shape = kwargs["noise"].shape
+            mask_ref_size = kwargs["attention_mask_img_shape"]
+            # the model will pad to the patch size, and then divide
+            # essentially dividing and rounding up
+            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
         return out
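As a worked example of the token-grid arithmetic above (image and latent sizes assumed): a 1024x1024 image maps to a 128x128 latent, and with a patch size of 2 the transformer sees ceil(128 / 2) = 64 tokens per side, i.e. 4096 image tokens, which is the grid the stored mask gets upscaled to.

import math

latent_h, latent_w = 128, 128      # kwargs["noise"].shape[2:4] for a 1024x1024 image (assumed)
patch_size = 2                     # Flux patch size
h_tok = math.ceil(latent_h / patch_size)
w_tok = math.ceil(latent_w / patch_size)
print(h_tok, w_tok, h_tok * w_tok)  # 64 64 4096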
@@ -26,6 +26,8 @@ import numpy as np
 from PIL import Image
 import logging
 import itertools
+from torch.nn.functional import interpolate
+from einops import rearrange

 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
@@ -873,5 +875,46 @@ def reshape_mask(input_mask, output_shape):
     mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
     if mask.shape[1] < output_shape[1]:
         mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
-    mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
+    mask = repeat_to_batch_size(mask, output_shape[0])
     return mask
+
+def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
+    hi, wi = img_size_in
+    ho, wo = img_size_out
+    # if it's already the correct size, no need to do anything
+    if (hi, wi) == (ho, wo):
+        return mask
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(0)
+    if mask.ndim != 3:
+        raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
+    txt_tokens = mask.shape[1] - (hi * wi)
+    # quadrants of the mask
+    txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
+    txt_to_img = mask[:, :txt_tokens, txt_tokens:]
+    img_to_img = mask[:, txt_tokens:, txt_tokens:]
+    img_to_txt = mask[:, txt_tokens:, :txt_tokens]
+
+    # convert to 1d x 2d, interpolate, then back to 1d x 1d
+    txt_to_img = rearrange(txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
+    txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
+    txt_to_img = rearrange(txt_to_img, "b t h w -> b t (h w)")
+    # this one is hard because we have to do it twice
+    # convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
+    img_to_img = rearrange(img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
+    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
+    img_to_img = rearrange(img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
+    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
+    img_to_img = rearrange(img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
+    # convert to 2d x 1d, interpolate, then back to 1d x 1d
+    img_to_txt = rearrange(img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
+    img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
+    img_to_txt = rearrange(img_to_txt, "b t h w -> b (h w) t")
+
+    # reassemble the mask from blocks
+    out = torch.cat([
+        torch.cat([txt_to_txt, txt_to_img], dim=2),
+        torch.cat([img_to_txt, img_to_img], dim=2)],
+        dim=1
+    )
+    return out
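A usage sketch for the new helper (assuming the comfy package is importable, shapes chosen for illustration): a mask built against a 2x2 reference image grid with 256 text tokens is rescaled to a 64x64 token grid. Only the rows and columns that correspond to image tokens are interpolated; the text-to-text quadrant is left untouched.

import torch
import comfy.utils

n_txt = 256
mask = torch.zeros(1, n_txt + 2 * 2, n_txt + 2 * 2)          # built against a 2x2 reference grid
mask_64 = comfy.utils.upscale_dit_mask(mask, (2, 2), (64, 64))
print(mask_64.shape)                                          # torch.Size([1, 4352, 4352]) = 256 + 64*64 per side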
nodes.py
@@ -1008,23 +1008,58 @@ class StyleModelApply:
                              "style_model": ("STYLE_MODEL", ),
                              "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
-                             "strength_type": (["multiply"], ),
+                             "strength_type": (["multiply", "attn_bias"], ),
                              }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_stylemodel"

     CATEGORY = "conditioning/style_model"

-    def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
+    def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
         cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
         if strength_type == "multiply":
             cond *= strength

-        c = []
+        n = cond.shape[1]
+        c_out = []
         for t in conditioning:
-            n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
-            c.append(n)
-        return (c, )
+            (txt, keys) = t
+            keys = keys.copy()
+            if strength_type == "attn_bias" and strength != 1.0:
+                # math.log raises an error if the argument is zero
+                # torch.log returns -inf, which is what we want
+                attn_bias = torch.log(torch.Tensor([strength]))
+                # get the size of the mask image
+                mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
+                n_ref = mask_ref_size[0] * mask_ref_size[1]
+                n_txt = txt.shape[1]
+                # grab the existing mask
+                mask = keys.get("attention_mask", None)
+                # create a default mask if it doesn't exist
+                if mask is None:
+                    mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
+                # convert the mask dtype, because it might be boolean
+                # we want it to be interpreted as a bias
+                if mask.dtype == torch.bool:
+                    # log(True) = log(1) = 0
+                    # log(False) = log(0) = -inf
+                    mask = torch.log(mask.to(dtype=torch.float16))
+                # now we make the mask bigger to add space for our new tokens
+                new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
+                # copy over the old mask, in quadrants
+                new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
+                new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
+                new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
+                new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
+                # now fill in the attention bias to our redux tokens
+                new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
+                new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
+                keys["attention_mask"] = new_mask.to(txt.device)
+                keys["attention_mask_img_shape"] = mask_ref_size
+
+            c_out.append([torch.cat((txt, cond), dim=1), keys])
+
+        return (c_out,)

 class unCLIPConditioning:
     @classmethod
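Why log(strength) works as the bias value used above (standalone, plain torch): adding log(s) to a key's attention logit scales that key's unnormalized softmax weight by exactly s, so strength 1.0 is a no-op and strength 0.0 (log -> -inf) removes the token entirely.

import torch

logits = torch.tensor([1.0, 2.0, 3.0])
strength = 0.5
biased = logits.clone()
biased[2] += torch.log(torch.tensor(strength))   # bias only the third token

w, wb = torch.softmax(logits, dim=0), torch.softmax(biased, dim=0)
print(wb[2] / w[2])                  # < 1: the biased token's normalized weight shrank
print(w[2] / w[0], wb[2] / wb[0])    # the ratio against an unbiased token is scaled by exactly 0.5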