mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Refactor and improve the sag node.
Moved all the sag related code to comfy_extras/nodes_sag.py
This commit is contained in:
parent
6761233e9d
commit
ba04a87d10
@ -61,6 +61,9 @@ class ModelPatcher:
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function):
|
||||
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
|
||||
@ -70,13 +73,17 @@ class ModelPatcher:
|
||||
to["patches"] = {}
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_patch_replace(self, patch, name, block_name, number):
|
||||
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
to["patches_replace"][name][(block_name, number)] = patch
|
||||
if transformer_index is not None:
|
||||
block = (block_name, number, transformer_index)
|
||||
else:
|
||||
block = (block_name, number)
|
||||
to["patches_replace"][name][block] = patch
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
@ -84,11 +91,11 @@ class ModelPatcher:
|
||||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn1", block_name, number)
|
||||
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
|
||||
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
|
||||
|
||||
def set_model_attn2_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn2", block_name, number)
|
||||
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
|
||||
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
|
@ -1,7 +1,6 @@
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import enum
|
||||
from comfy import model_management
|
||||
import math
|
||||
@ -9,10 +8,6 @@ from comfy import model_base
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
@ -246,73 +241,27 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
del out_uncond_count
|
||||
return out_cond, out_uncond
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
if math.isclose(cond_scale, 1.0):
|
||||
uncond_ = None
|
||||
else:
|
||||
uncond_ = uncond
|
||||
|
||||
# if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out.
|
||||
if math.isclose(cond_scale, 1.0) and "sag" not in model_options:
|
||||
uncond = None
|
||||
|
||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
|
||||
if "sag" in model_options:
|
||||
assert uncond is not None, "SAG requires uncond guidance"
|
||||
sag_scale = model_options["sag_scale"]
|
||||
sag_sigma = model_options["sag_sigma"]
|
||||
sag_threshold = model_options.get("sag_threshold", 1.0)
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
# these methods are added by the sag patcher
|
||||
uncond_attn = model.get_attn_scores()
|
||||
mid_shape = model.get_mid_block_shape()
|
||||
# create the adversarially blurred image
|
||||
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
||||
degraded_noised = degraded + x - uncond_pred
|
||||
# call into the UNet
|
||||
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
||||
cfg_result += (degraded - sag) * sag_scale
|
||||
return cfg_result
|
||||
|
||||
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
||||
# reshape and GAP the attention map
|
||||
_, hw1, hw2 = attn.shape
|
||||
b, _, lh, lw = x0.shape
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, *mid_shape)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
|
||||
# Upsample
|
||||
mask = F.interpolate(mask, (lh, lw))
|
||||
|
||||
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||
blurred = blurred * mask + x0 * (1 - mask)
|
||||
return blurred
|
||||
|
||||
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||
ksize_half = (kernel_size - 1) * 0.5
|
||||
|
||||
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
|
||||
x_kernel = pdf / pdf.sum()
|
||||
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||
|
||||
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
|
||||
img = F.pad(img, padding, mode="reflect")
|
||||
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||
return img
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
|
@ -1,8 +1,12 @@
|
||||
import torch
|
||||
from torch import einsum
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import os
|
||||
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
||||
import comfy.samplers
|
||||
|
||||
# from comfy/ldm/modules/attention.py
|
||||
# but modified to return attention scores as well as output
|
||||
@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||
)
|
||||
return (out, sim)
|
||||
|
||||
class SagNode:
|
||||
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
# reshape and GAP the attention map
|
||||
_, hw1, hw2 = attn.shape
|
||||
b, _, lh, lw = x0.shape
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
ratio = round(math.sqrt(lh * lw / hw1))
|
||||
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, *mid_shape)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
|
||||
# Upsample
|
||||
mask = F.interpolate(mask, (lh, lw))
|
||||
|
||||
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||
blurred = blurred * mask + x0 * (1 - mask)
|
||||
return blurred
|
||||
|
||||
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||
ksize_half = (kernel_size - 1) * 0.5
|
||||
|
||||
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
|
||||
x_kernel = pdf / pdf.sum()
|
||||
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||
|
||||
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
|
||||
img = F.pad(img, padding, mode="reflect")
|
||||
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||
return img
|
||||
|
||||
class SelfAttentionGuidance:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
@ -63,15 +109,9 @@ class SagNode:
|
||||
|
||||
def patch(self, model, scale, blur_sigma):
|
||||
m = model.clone()
|
||||
# set extra options on the model
|
||||
m.model_options["sag"] = True
|
||||
m.model_options["sag_scale"] = scale
|
||||
m.model_options["sag_sigma"] = blur_sigma
|
||||
|
||||
attn_scores = None
|
||||
mid_block_shape = None
|
||||
m.model.get_attn_scores = lambda: attn_scores
|
||||
m.model.get_mid_block_shape = lambda: mid_block_shape
|
||||
|
||||
# TODO: make this work properly with chunked batches
|
||||
# currently, we can only save the attn from one UNet call
|
||||
@ -92,24 +132,41 @@ class SagNode:
|
||||
else:
|
||||
return optimized_attention(q, k, v, heads=heads)
|
||||
|
||||
def post_cfg_function(args):
|
||||
nonlocal attn_scores
|
||||
nonlocal mid_block_shape
|
||||
uncond_attn = attn_scores
|
||||
|
||||
sag_scale = scale
|
||||
sag_sigma = blur_sigma
|
||||
sag_threshold = 1.0
|
||||
model = args["model"]
|
||||
uncond_pred = args["uncond_denoised"]
|
||||
uncond = args["uncond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
model_options = args["model_options"]
|
||||
x = args["input"]
|
||||
|
||||
# create the adversarially blurred image
|
||||
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||
degraded_noised = degraded + x - uncond_pred
|
||||
# call into the UNet
|
||||
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
||||
return cfg_result + (degraded - sag) * sag_scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||
|
||||
# from diffusers:
|
||||
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
||||
def set_model_patch_replace(patch, name, key):
|
||||
to = m.model_options["transformer_options"]
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
to["patches_replace"][name][key] = patch
|
||||
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
|
||||
# from diffusers:
|
||||
# unet.mid_block.attentions[0].register_forward_hook()
|
||||
def forward_hook(m, inp, out):
|
||||
nonlocal mid_block_shape
|
||||
mid_block_shape = out[0].shape[-2:]
|
||||
m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
|
||||
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
|
||||
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"Self-Attention Guidance": SagNode,
|
||||
"SelfAttentionGuidance": SelfAttentionGuidance,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SelfAttentionGuidance": "Self-Attention Guidance",
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user