Refactor and improve the sag node.

Moved all the sag related code to comfy_extras/nodes_sag.py
This commit is contained in:
comfyanonymous 2023-12-13 16:10:03 -05:00
parent 6761233e9d
commit ba04a87d10
3 changed files with 334 additions and 321 deletions

View File

@ -61,6 +61,9 @@ class ModelPatcher:
else: else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_sampler_post_cfg_function(self, post_cfg_function):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
def set_model_unet_function_wrapper(self, unet_wrapper_function): def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function self.model_options["model_function_wrapper"] = unet_wrapper_function
@ -70,13 +73,17 @@ class ModelPatcher:
to["patches"] = {} to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch] to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number): def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches_replace" not in to: if "patches_replace" not in to:
to["patches_replace"] = {} to["patches_replace"] = {}
if name not in to["patches_replace"]: if name not in to["patches_replace"]:
to["patches_replace"][name] = {} to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch): def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch") self.set_model_patch(patch, "attn1_patch")
@ -84,11 +91,11 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch): def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch") self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number): def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number) self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number): def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number) self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch): def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch") self.set_model_patch(patch, "attn1_output_patch")

View File

@ -1,7 +1,6 @@
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
import torch import torch
import torch.nn.functional as F
import enum import enum
from comfy import model_management from comfy import model_management
import math import math
@ -9,10 +8,6 @@ from comfy import model_base
import comfy.utils import comfy.utils
import comfy.conds import comfy.conds
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
@ -246,73 +241,27 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
del out_uncond_count del out_uncond_count
return out_cond, out_uncond return out_cond, out_uncond
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0):
uncond_ = None
else:
uncond_ = uncond
# if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out. cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if math.isclose(cond_scale, 1.0) and "sag" not in model_options:
uncond = None
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
cfg_result = x - model_options["sampler_cfg_function"](args) cfg_result = x - model_options["sampler_cfg_function"](args)
if "sag" in model_options: for fn in model_options.get("sampler_post_cfg_function", []):
assert uncond is not None, "SAG requires uncond guidance" args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
sag_scale = model_options["sag_scale"] "sigma": timestep, "model_options": model_options, "input": x}
sag_sigma = model_options["sag_sigma"] cfg_result = fn(args)
sag_threshold = model_options.get("sag_threshold", 1.0)
# these methods are added by the sag patcher
uncond_attn = model.get_attn_scores()
mid_shape = model.get_mid_block_shape()
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
cfg_result += (degraded - sag) * sag_scale
return cfg_result return cfg_result
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class CFGNoisePredictor(torch.nn.Module): class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()

View File

@ -1,8 +1,12 @@
import torch import torch
from torch import einsum from torch import einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat from einops import rearrange, repeat
import os import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
import comfy.samplers
# from comfy/ldm/modules/attention.py # from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output # but modified to return attention scores as well as output
@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
) )
return (out, sim) return (out, sim)
class SagNode: def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = round(math.sqrt(lh * lw / hw1))
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
@ -63,15 +109,9 @@ class SagNode:
def patch(self, model, scale, blur_sigma): def patch(self, model, scale, blur_sigma):
m = model.clone() m = model.clone()
# set extra options on the model
m.model_options["sag"] = True
m.model_options["sag_scale"] = scale
m.model_options["sag_sigma"] = blur_sigma
attn_scores = None attn_scores = None
mid_block_shape = None mid_block_shape = None
m.model.get_attn_scores = lambda: attn_scores
m.model.get_mid_block_shape = lambda: mid_block_shape
# TODO: make this work properly with chunked batches # TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call # currently, we can only save the attn from one UNet call
@ -92,24 +132,41 @@ class SagNode:
else: else:
return optimized_attention(q, k, v, heads=heads) return optimized_attention(q, k, v, heads=heads)
def post_cfg_function(args):
nonlocal attn_scores
nonlocal mid_block_shape
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function)
# from diffusers: # from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
def set_model_patch_replace(patch, name, key): m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
to = m.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][key] = patch
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
# from diffusers:
# unet.mid_block.attentions[0].register_forward_hook()
def forward_hook(m, inp, out):
nonlocal mid_block_shape
mid_block_shape = out[0].shape[-2:]
m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
return (m, ) return (m, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"Self-Attention Guidance": SagNode, "SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SelfAttentionGuidance": "Self-Attention Guidance",
} }