mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-14 13:17:32 +00:00
Implement Self-Attention Guidance (#2201)
* First SAG test * need to put extra options on the model instead of patcher * no errors and results seem not-broken * Use @ashen-uncensored formula, which works better!!! * Fix a crash when using weird resolutions. Remove an unnecessary UNet call * Improve comments, optimize memory in blur routine * SAG works with sampler_cfg_function
This commit is contained in:
parent
390078904c
commit
6761233e9d
@ -1,6 +1,7 @@
|
|||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import enum
|
import enum
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import math
|
import math
|
||||||
@ -60,10 +61,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|||||||
for t in range(rr):
|
for t in range(rr):
|
||||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||||
|
|
||||||
conditionning = {}
|
conditioning = {}
|
||||||
model_conds = conds["model_conds"]
|
model_conds = conds["model_conds"]
|
||||||
for c in model_conds:
|
for c in model_conds:
|
||||||
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||||
|
|
||||||
control = None
|
control = None
|
||||||
if 'control' in conds:
|
if 'control' in conds:
|
||||||
@ -82,7 +83,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|||||||
|
|
||||||
patches['middle_patch'] = [gligen_patch]
|
patches['middle_patch'] = [gligen_patch]
|
||||||
|
|
||||||
return (input_x, mult, conditionning, area, control, patches)
|
return (input_x, mult, conditioning, area, control, patches)
|
||||||
|
|
||||||
def cond_equal_size(c1, c2):
|
def cond_equal_size(c1, c2):
|
||||||
if c1 is c2:
|
if c1 is c2:
|
||||||
@ -246,15 +247,71 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|||||||
return out_cond, out_uncond
|
return out_cond, out_uncond
|
||||||
|
|
||||||
|
|
||||||
if math.isclose(cond_scale, 1.0):
|
# if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out.
|
||||||
|
if math.isclose(cond_scale, 1.0) and "sag" not in model_options:
|
||||||
uncond = None
|
uncond = None
|
||||||
|
|
||||||
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||||
|
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||||
return x - model_options["sampler_cfg_function"](args)
|
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||||
else:
|
|
||||||
return uncond + (cond - uncond) * cond_scale
|
if "sag" in model_options:
|
||||||
|
assert uncond is not None, "SAG requires uncond guidance"
|
||||||
|
sag_scale = model_options["sag_scale"]
|
||||||
|
sag_sigma = model_options["sag_sigma"]
|
||||||
|
sag_threshold = model_options.get("sag_threshold", 1.0)
|
||||||
|
|
||||||
|
# these methods are added by the sag patcher
|
||||||
|
uncond_attn = model.get_attn_scores()
|
||||||
|
mid_shape = model.get_mid_block_shape()
|
||||||
|
# create the adversarially blurred image
|
||||||
|
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
||||||
|
degraded_noised = degraded + x - uncond_pred
|
||||||
|
# call into the UNet
|
||||||
|
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
||||||
|
cfg_result += (degraded - sag) * sag_scale
|
||||||
|
return cfg_result
|
||||||
|
|
||||||
|
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
||||||
|
# reshape and GAP the attention map
|
||||||
|
_, hw1, hw2 = attn.shape
|
||||||
|
b, _, lh, lw = x0.shape
|
||||||
|
attn = attn.reshape(b, -1, hw1, hw2)
|
||||||
|
# Global Average Pool
|
||||||
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||||
|
# Reshape
|
||||||
|
mask = (
|
||||||
|
mask.reshape(b, *mid_shape)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.type(attn.dtype)
|
||||||
|
)
|
||||||
|
# Upsample
|
||||||
|
mask = F.interpolate(mask, (lh, lw))
|
||||||
|
|
||||||
|
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||||
|
blurred = blurred * mask + x0 * (1 - mask)
|
||||||
|
return blurred
|
||||||
|
|
||||||
|
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||||
|
ksize_half = (kernel_size - 1) * 0.5
|
||||||
|
|
||||||
|
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||||
|
|
||||||
|
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||||
|
|
||||||
|
x_kernel = pdf / pdf.sum()
|
||||||
|
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
|
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||||
|
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||||
|
|
||||||
|
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||||
|
|
||||||
|
img = F.pad(img, padding, mode="reflect")
|
||||||
|
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||||
|
return img
|
||||||
|
|
||||||
class CFGNoisePredictor(torch.nn.Module):
|
class CFGNoisePredictor(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
|
115
comfy_extras/nodes_sag.py
Normal file
115
comfy_extras/nodes_sag.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
import os
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
||||||
|
|
||||||
|
# from comfy/ldm/modules/attention.py
|
||||||
|
# but modified to return attention scores as well as output
|
||||||
|
def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
h = heads
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.unsqueeze(3)
|
||||||
|
.reshape(b, -1, heads, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b * heads, -1, dim_head)
|
||||||
|
.contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
# force cast to fp32 to avoid overflowing
|
||||||
|
if _ATTN_PRECISION =="fp32":
|
||||||
|
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||||
|
q, k = q.float(), k.float()
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
|
else:
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
|
|
||||||
|
del q, k
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||||
|
out = (
|
||||||
|
out.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
return (out, sim)
|
||||||
|
|
||||||
|
class SagNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
|
||||||
|
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, scale, blur_sigma):
|
||||||
|
m = model.clone()
|
||||||
|
# set extra options on the model
|
||||||
|
m.model_options["sag"] = True
|
||||||
|
m.model_options["sag_scale"] = scale
|
||||||
|
m.model_options["sag_sigma"] = blur_sigma
|
||||||
|
|
||||||
|
attn_scores = None
|
||||||
|
mid_block_shape = None
|
||||||
|
m.model.get_attn_scores = lambda: attn_scores
|
||||||
|
m.model.get_mid_block_shape = lambda: mid_block_shape
|
||||||
|
|
||||||
|
# TODO: make this work properly with chunked batches
|
||||||
|
# currently, we can only save the attn from one UNet call
|
||||||
|
def attn_and_record(q, k, v, extra_options):
|
||||||
|
nonlocal attn_scores
|
||||||
|
# if uncond, save the attention scores
|
||||||
|
heads = extra_options["n_heads"]
|
||||||
|
cond_or_uncond = extra_options["cond_or_uncond"]
|
||||||
|
b = q.shape[0] // len(cond_or_uncond)
|
||||||
|
if 1 in cond_or_uncond:
|
||||||
|
uncond_index = cond_or_uncond.index(1)
|
||||||
|
# do the entire attention operation, but save the attention scores to attn_scores
|
||||||
|
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
|
||||||
|
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
|
||||||
|
n_slices = heads * b
|
||||||
|
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return optimized_attention(q, k, v, heads=heads)
|
||||||
|
|
||||||
|
# from diffusers:
|
||||||
|
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
||||||
|
def set_model_patch_replace(patch, name, key):
|
||||||
|
to = m.model_options["transformer_options"]
|
||||||
|
if "patches_replace" not in to:
|
||||||
|
to["patches_replace"] = {}
|
||||||
|
if name not in to["patches_replace"]:
|
||||||
|
to["patches_replace"][name] = {}
|
||||||
|
to["patches_replace"][name][key] = patch
|
||||||
|
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
|
||||||
|
# from diffusers:
|
||||||
|
# unet.mid_block.attentions[0].register_forward_hook()
|
||||||
|
def forward_hook(m, inp, out):
|
||||||
|
nonlocal mid_block_shape
|
||||||
|
mid_block_shape = out[0].shape[-2:]
|
||||||
|
m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"Self-Attention Guidance": SagNode,
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user