mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Refactor and improve the sag node.
Moved all the sag related code to comfy_extras/nodes_sag.py
This commit is contained in:
parent
6761233e9d
commit
ba04a87d10
@ -61,6 +61,9 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||||
|
|
||||||
|
def set_model_sampler_post_cfg_function(self, post_cfg_function):
|
||||||
|
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
||||||
|
|
||||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||||
|
|
||||||
@ -70,13 +73,17 @@ class ModelPatcher:
|
|||||||
to["patches"] = {}
|
to["patches"] = {}
|
||||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||||
|
|
||||||
def set_model_patch_replace(self, patch, name, block_name, number):
|
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches_replace" not in to:
|
if "patches_replace" not in to:
|
||||||
to["patches_replace"] = {}
|
to["patches_replace"] = {}
|
||||||
if name not in to["patches_replace"]:
|
if name not in to["patches_replace"]:
|
||||||
to["patches_replace"][name] = {}
|
to["patches_replace"][name] = {}
|
||||||
to["patches_replace"][name][(block_name, number)] = patch
|
if transformer_index is not None:
|
||||||
|
block = (block_name, number, transformer_index)
|
||||||
|
else:
|
||||||
|
block = (block_name, number)
|
||||||
|
to["patches_replace"][name][block] = patch
|
||||||
|
|
||||||
def set_model_attn1_patch(self, patch):
|
def set_model_attn1_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn1_patch")
|
self.set_model_patch(patch, "attn1_patch")
|
||||||
@ -84,11 +91,11 @@ class ModelPatcher:
|
|||||||
def set_model_attn2_patch(self, patch):
|
def set_model_attn2_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn2_patch")
|
self.set_model_patch(patch, "attn2_patch")
|
||||||
|
|
||||||
def set_model_attn1_replace(self, patch, block_name, number):
|
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
|
||||||
self.set_model_patch_replace(patch, "attn1", block_name, number)
|
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
|
||||||
|
|
||||||
def set_model_attn2_replace(self, patch, block_name, number):
|
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
|
||||||
self.set_model_patch_replace(patch, "attn2", block_name, number)
|
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
|
||||||
|
|
||||||
def set_model_attn1_output_patch(self, patch):
|
def set_model_attn1_output_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn1_output_patch")
|
self.set_model_patch(patch, "attn1_output_patch")
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from .k_diffusion import sampling as k_diffusion_sampling
|
from .k_diffusion import sampling as k_diffusion_sampling
|
||||||
from .extra_samplers import uni_pc
|
from .extra_samplers import uni_pc
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import enum
|
import enum
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import math
|
import math
|
||||||
@ -9,310 +8,260 @@ from comfy import model_base
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.conds
|
import comfy.conds
|
||||||
|
|
||||||
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
|
strength = 1.0
|
||||||
|
|
||||||
|
if 'timestep_start' in conds:
|
||||||
|
timestep_start = conds['timestep_start']
|
||||||
|
if timestep_in[0] > timestep_start:
|
||||||
|
return None
|
||||||
|
if 'timestep_end' in conds:
|
||||||
|
timestep_end = conds['timestep_end']
|
||||||
|
if timestep_in[0] < timestep_end:
|
||||||
|
return None
|
||||||
|
if 'area' in conds:
|
||||||
|
area = conds['area']
|
||||||
|
if 'strength' in conds:
|
||||||
|
strength = conds['strength']
|
||||||
|
|
||||||
|
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
|
if 'mask' in conds:
|
||||||
|
# Scale the mask to the size of the input
|
||||||
|
# The mask should have been resized as we began the sampling process
|
||||||
|
mask_strength = 1.0
|
||||||
|
if "mask_strength" in conds:
|
||||||
|
mask_strength = conds["mask_strength"]
|
||||||
|
mask = conds['mask']
|
||||||
|
assert(mask.shape[1] == x_in.shape[2])
|
||||||
|
assert(mask.shape[2] == x_in.shape[3])
|
||||||
|
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||||
|
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||||
|
else:
|
||||||
|
mask = torch.ones_like(input_x)
|
||||||
|
mult = mask * strength
|
||||||
|
|
||||||
|
if 'mask' not in conds:
|
||||||
|
rr = 8
|
||||||
|
if area[2] != 0:
|
||||||
|
for t in range(rr):
|
||||||
|
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
||||||
|
if (area[0] + area[2]) < x_in.shape[2]:
|
||||||
|
for t in range(rr):
|
||||||
|
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
||||||
|
if area[3] != 0:
|
||||||
|
for t in range(rr):
|
||||||
|
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
||||||
|
if (area[1] + area[3]) < x_in.shape[3]:
|
||||||
|
for t in range(rr):
|
||||||
|
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||||
|
|
||||||
|
conditioning = {}
|
||||||
|
model_conds = conds["model_conds"]
|
||||||
|
for c in model_conds:
|
||||||
|
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||||
|
|
||||||
|
control = None
|
||||||
|
if 'control' in conds:
|
||||||
|
control = conds['control']
|
||||||
|
|
||||||
|
patches = None
|
||||||
|
if 'gligen' in conds:
|
||||||
|
gligen = conds['gligen']
|
||||||
|
patches = {}
|
||||||
|
gligen_type = gligen[0]
|
||||||
|
gligen_model = gligen[1]
|
||||||
|
if gligen_type == "position":
|
||||||
|
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
|
||||||
|
else:
|
||||||
|
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
|
||||||
|
|
||||||
|
patches['middle_patch'] = [gligen_patch]
|
||||||
|
|
||||||
|
return (input_x, mult, conditioning, area, control, patches)
|
||||||
|
|
||||||
|
def cond_equal_size(c1, c2):
|
||||||
|
if c1 is c2:
|
||||||
|
return True
|
||||||
|
if c1.keys() != c2.keys():
|
||||||
|
return False
|
||||||
|
for k in c1:
|
||||||
|
if not c1[k].can_concat(c2[k]):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def can_concat_cond(c1, c2):
|
||||||
|
if c1[0].shape != c2[0].shape:
|
||||||
|
return False
|
||||||
|
|
||||||
|
#control
|
||||||
|
if (c1[4] is None) != (c2[4] is None):
|
||||||
|
return False
|
||||||
|
if c1[4] is not None:
|
||||||
|
if c1[4] is not c2[4]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
#patches
|
||||||
|
if (c1[5] is None) != (c2[5] is None):
|
||||||
|
return False
|
||||||
|
if (c1[5] is not None):
|
||||||
|
if c1[5] is not c2[5]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return cond_equal_size(c1[2], c2[2])
|
||||||
|
|
||||||
|
def cond_cat(c_list):
|
||||||
|
c_crossattn = []
|
||||||
|
c_concat = []
|
||||||
|
c_adm = []
|
||||||
|
crossattn_max_len = 0
|
||||||
|
|
||||||
|
temp = {}
|
||||||
|
for x in c_list:
|
||||||
|
for k in x:
|
||||||
|
cur = temp.get(k, [])
|
||||||
|
cur.append(x[k])
|
||||||
|
temp[k] = cur
|
||||||
|
|
||||||
|
out = {}
|
||||||
|
for k in temp:
|
||||||
|
conds = temp[k]
|
||||||
|
out[k] = conds[0].concat(conds[1:])
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||||
|
out_cond = torch.zeros_like(x_in)
|
||||||
|
out_count = torch.ones_like(x_in) * 1e-37
|
||||||
|
|
||||||
|
out_uncond = torch.zeros_like(x_in)
|
||||||
|
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||||
|
|
||||||
|
COND = 0
|
||||||
|
UNCOND = 1
|
||||||
|
|
||||||
|
to_run = []
|
||||||
|
for x in cond:
|
||||||
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
|
if p is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
to_run += [(p, COND)]
|
||||||
|
if uncond is not None:
|
||||||
|
for x in uncond:
|
||||||
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
|
if p is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
to_run += [(p, UNCOND)]
|
||||||
|
|
||||||
|
while len(to_run) > 0:
|
||||||
|
first = to_run[0]
|
||||||
|
first_shape = first[0][0].shape
|
||||||
|
to_batch_temp = []
|
||||||
|
for x in range(len(to_run)):
|
||||||
|
if can_concat_cond(to_run[x][0], first[0]):
|
||||||
|
to_batch_temp += [x]
|
||||||
|
|
||||||
|
to_batch_temp.reverse()
|
||||||
|
to_batch = to_batch_temp[:1]
|
||||||
|
|
||||||
|
free_memory = model_management.get_free_memory(x_in.device)
|
||||||
|
for i in range(1, len(to_batch_temp) + 1):
|
||||||
|
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||||
|
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||||
|
if model.memory_required(input_shape) < free_memory:
|
||||||
|
to_batch = batch_amount
|
||||||
|
break
|
||||||
|
|
||||||
|
input_x = []
|
||||||
|
mult = []
|
||||||
|
c = []
|
||||||
|
cond_or_uncond = []
|
||||||
|
area = []
|
||||||
|
control = None
|
||||||
|
patches = None
|
||||||
|
for x in to_batch:
|
||||||
|
o = to_run.pop(x)
|
||||||
|
p = o[0]
|
||||||
|
input_x += [p[0]]
|
||||||
|
mult += [p[1]]
|
||||||
|
c += [p[2]]
|
||||||
|
area += [p[3]]
|
||||||
|
cond_or_uncond += [o[1]]
|
||||||
|
control = p[4]
|
||||||
|
patches = p[5]
|
||||||
|
|
||||||
|
batch_chunks = len(cond_or_uncond)
|
||||||
|
input_x = torch.cat(input_x)
|
||||||
|
c = cond_cat(c)
|
||||||
|
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||||
|
|
||||||
|
if control is not None:
|
||||||
|
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||||
|
|
||||||
|
transformer_options = {}
|
||||||
|
if 'transformer_options' in model_options:
|
||||||
|
transformer_options = model_options['transformer_options'].copy()
|
||||||
|
|
||||||
|
if patches is not None:
|
||||||
|
if "patches" in transformer_options:
|
||||||
|
cur_patches = transformer_options["patches"].copy()
|
||||||
|
for p in patches:
|
||||||
|
if p in cur_patches:
|
||||||
|
cur_patches[p] = cur_patches[p] + patches[p]
|
||||||
|
else:
|
||||||
|
cur_patches[p] = patches[p]
|
||||||
|
else:
|
||||||
|
transformer_options["patches"] = patches
|
||||||
|
|
||||||
|
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||||
|
transformer_options["sigmas"] = timestep
|
||||||
|
|
||||||
|
c['transformer_options'] = transformer_options
|
||||||
|
|
||||||
|
if 'model_function_wrapper' in model_options:
|
||||||
|
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||||
|
else:
|
||||||
|
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||||
|
del input_x
|
||||||
|
|
||||||
|
for o in range(batch_chunks):
|
||||||
|
if cond_or_uncond[o] == COND:
|
||||||
|
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||||
|
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||||
|
else:
|
||||||
|
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||||
|
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||||
|
del mult
|
||||||
|
|
||||||
|
out_cond /= out_count
|
||||||
|
del out_count
|
||||||
|
out_uncond /= out_uncond_count
|
||||||
|
del out_uncond_count
|
||||||
|
return out_cond, out_uncond
|
||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns denoised
|
#Returns denoised
|
||||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
if math.isclose(cond_scale, 1.0):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
uncond_ = None
|
||||||
strength = 1.0
|
else:
|
||||||
|
uncond_ = uncond
|
||||||
|
|
||||||
if 'timestep_start' in conds:
|
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
||||||
timestep_start = conds['timestep_start']
|
|
||||||
if timestep_in[0] > timestep_start:
|
|
||||||
return None
|
|
||||||
if 'timestep_end' in conds:
|
|
||||||
timestep_end = conds['timestep_end']
|
|
||||||
if timestep_in[0] < timestep_end:
|
|
||||||
return None
|
|
||||||
if 'area' in conds:
|
|
||||||
area = conds['area']
|
|
||||||
if 'strength' in conds:
|
|
||||||
strength = conds['strength']
|
|
||||||
|
|
||||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
|
||||||
if 'mask' in conds:
|
|
||||||
# Scale the mask to the size of the input
|
|
||||||
# The mask should have been resized as we began the sampling process
|
|
||||||
mask_strength = 1.0
|
|
||||||
if "mask_strength" in conds:
|
|
||||||
mask_strength = conds["mask_strength"]
|
|
||||||
mask = conds['mask']
|
|
||||||
assert(mask.shape[1] == x_in.shape[2])
|
|
||||||
assert(mask.shape[2] == x_in.shape[3])
|
|
||||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
|
||||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
|
||||||
else:
|
|
||||||
mask = torch.ones_like(input_x)
|
|
||||||
mult = mask * strength
|
|
||||||
|
|
||||||
if 'mask' not in conds:
|
|
||||||
rr = 8
|
|
||||||
if area[2] != 0:
|
|
||||||
for t in range(rr):
|
|
||||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
|
||||||
if (area[0] + area[2]) < x_in.shape[2]:
|
|
||||||
for t in range(rr):
|
|
||||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
|
||||||
if area[3] != 0:
|
|
||||||
for t in range(rr):
|
|
||||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
|
||||||
if (area[1] + area[3]) < x_in.shape[3]:
|
|
||||||
for t in range(rr):
|
|
||||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
|
||||||
|
|
||||||
conditioning = {}
|
|
||||||
model_conds = conds["model_conds"]
|
|
||||||
for c in model_conds:
|
|
||||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
|
||||||
|
|
||||||
control = None
|
|
||||||
if 'control' in conds:
|
|
||||||
control = conds['control']
|
|
||||||
|
|
||||||
patches = None
|
|
||||||
if 'gligen' in conds:
|
|
||||||
gligen = conds['gligen']
|
|
||||||
patches = {}
|
|
||||||
gligen_type = gligen[0]
|
|
||||||
gligen_model = gligen[1]
|
|
||||||
if gligen_type == "position":
|
|
||||||
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
|
|
||||||
else:
|
|
||||||
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
|
|
||||||
|
|
||||||
patches['middle_patch'] = [gligen_patch]
|
|
||||||
|
|
||||||
return (input_x, mult, conditioning, area, control, patches)
|
|
||||||
|
|
||||||
def cond_equal_size(c1, c2):
|
|
||||||
if c1 is c2:
|
|
||||||
return True
|
|
||||||
if c1.keys() != c2.keys():
|
|
||||||
return False
|
|
||||||
for k in c1:
|
|
||||||
if not c1[k].can_concat(c2[k]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def can_concat_cond(c1, c2):
|
|
||||||
if c1[0].shape != c2[0].shape:
|
|
||||||
return False
|
|
||||||
|
|
||||||
#control
|
|
||||||
if (c1[4] is None) != (c2[4] is None):
|
|
||||||
return False
|
|
||||||
if c1[4] is not None:
|
|
||||||
if c1[4] is not c2[4]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
#patches
|
|
||||||
if (c1[5] is None) != (c2[5] is None):
|
|
||||||
return False
|
|
||||||
if (c1[5] is not None):
|
|
||||||
if c1[5] is not c2[5]:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return cond_equal_size(c1[2], c2[2])
|
|
||||||
|
|
||||||
def cond_cat(c_list):
|
|
||||||
c_crossattn = []
|
|
||||||
c_concat = []
|
|
||||||
c_adm = []
|
|
||||||
crossattn_max_len = 0
|
|
||||||
|
|
||||||
temp = {}
|
|
||||||
for x in c_list:
|
|
||||||
for k in x:
|
|
||||||
cur = temp.get(k, [])
|
|
||||||
cur.append(x[k])
|
|
||||||
temp[k] = cur
|
|
||||||
|
|
||||||
out = {}
|
|
||||||
for k in temp:
|
|
||||||
conds = temp[k]
|
|
||||||
out[k] = conds[0].concat(conds[1:])
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
|
||||||
out_cond = torch.zeros_like(x_in)
|
|
||||||
out_count = torch.ones_like(x_in) * 1e-37
|
|
||||||
|
|
||||||
out_uncond = torch.zeros_like(x_in)
|
|
||||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
|
||||||
|
|
||||||
COND = 0
|
|
||||||
UNCOND = 1
|
|
||||||
|
|
||||||
to_run = []
|
|
||||||
for x in cond:
|
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
|
||||||
if p is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
|
||||||
if uncond is not None:
|
|
||||||
for x in uncond:
|
|
||||||
p = get_area_and_mult(x, x_in, timestep)
|
|
||||||
if p is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
to_run += [(p, UNCOND)]
|
|
||||||
|
|
||||||
while len(to_run) > 0:
|
|
||||||
first = to_run[0]
|
|
||||||
first_shape = first[0][0].shape
|
|
||||||
to_batch_temp = []
|
|
||||||
for x in range(len(to_run)):
|
|
||||||
if can_concat_cond(to_run[x][0], first[0]):
|
|
||||||
to_batch_temp += [x]
|
|
||||||
|
|
||||||
to_batch_temp.reverse()
|
|
||||||
to_batch = to_batch_temp[:1]
|
|
||||||
|
|
||||||
free_memory = model_management.get_free_memory(x_in.device)
|
|
||||||
for i in range(1, len(to_batch_temp) + 1):
|
|
||||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
|
||||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
|
||||||
if model.memory_required(input_shape) < free_memory:
|
|
||||||
to_batch = batch_amount
|
|
||||||
break
|
|
||||||
|
|
||||||
input_x = []
|
|
||||||
mult = []
|
|
||||||
c = []
|
|
||||||
cond_or_uncond = []
|
|
||||||
area = []
|
|
||||||
control = None
|
|
||||||
patches = None
|
|
||||||
for x in to_batch:
|
|
||||||
o = to_run.pop(x)
|
|
||||||
p = o[0]
|
|
||||||
input_x += [p[0]]
|
|
||||||
mult += [p[1]]
|
|
||||||
c += [p[2]]
|
|
||||||
area += [p[3]]
|
|
||||||
cond_or_uncond += [o[1]]
|
|
||||||
control = p[4]
|
|
||||||
patches = p[5]
|
|
||||||
|
|
||||||
batch_chunks = len(cond_or_uncond)
|
|
||||||
input_x = torch.cat(input_x)
|
|
||||||
c = cond_cat(c)
|
|
||||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
|
||||||
|
|
||||||
if control is not None:
|
|
||||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
|
||||||
|
|
||||||
transformer_options = {}
|
|
||||||
if 'transformer_options' in model_options:
|
|
||||||
transformer_options = model_options['transformer_options'].copy()
|
|
||||||
|
|
||||||
if patches is not None:
|
|
||||||
if "patches" in transformer_options:
|
|
||||||
cur_patches = transformer_options["patches"].copy()
|
|
||||||
for p in patches:
|
|
||||||
if p in cur_patches:
|
|
||||||
cur_patches[p] = cur_patches[p] + patches[p]
|
|
||||||
else:
|
|
||||||
cur_patches[p] = patches[p]
|
|
||||||
else:
|
|
||||||
transformer_options["patches"] = patches
|
|
||||||
|
|
||||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
|
||||||
transformer_options["sigmas"] = timestep
|
|
||||||
|
|
||||||
c['transformer_options'] = transformer_options
|
|
||||||
|
|
||||||
if 'model_function_wrapper' in model_options:
|
|
||||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
|
||||||
else:
|
|
||||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
|
||||||
del input_x
|
|
||||||
|
|
||||||
for o in range(batch_chunks):
|
|
||||||
if cond_or_uncond[o] == COND:
|
|
||||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
|
||||||
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
|
||||||
else:
|
|
||||||
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
|
||||||
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
|
||||||
del mult
|
|
||||||
|
|
||||||
out_cond /= out_count
|
|
||||||
del out_count
|
|
||||||
out_uncond /= out_uncond_count
|
|
||||||
del out_uncond_count
|
|
||||||
return out_cond, out_uncond
|
|
||||||
|
|
||||||
|
|
||||||
# if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out.
|
|
||||||
if math.isclose(cond_scale, 1.0) and "sag" not in model_options:
|
|
||||||
uncond = None
|
|
||||||
|
|
||||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
|
||||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||||
|
|
||||||
if "sag" in model_options:
|
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||||
assert uncond is not None, "SAG requires uncond guidance"
|
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||||
sag_scale = model_options["sag_scale"]
|
"sigma": timestep, "model_options": model_options, "input": x}
|
||||||
sag_sigma = model_options["sag_sigma"]
|
cfg_result = fn(args)
|
||||||
sag_threshold = model_options.get("sag_threshold", 1.0)
|
|
||||||
|
|
||||||
# these methods are added by the sag patcher
|
|
||||||
uncond_attn = model.get_attn_scores()
|
|
||||||
mid_shape = model.get_mid_block_shape()
|
|
||||||
# create the adversarially blurred image
|
|
||||||
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
|
|
||||||
degraded_noised = degraded + x - uncond_pred
|
|
||||||
# call into the UNet
|
|
||||||
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
|
|
||||||
cfg_result += (degraded - sag) * sag_scale
|
|
||||||
return cfg_result
|
return cfg_result
|
||||||
|
|
||||||
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0):
|
|
||||||
# reshape and GAP the attention map
|
|
||||||
_, hw1, hw2 = attn.shape
|
|
||||||
b, _, lh, lw = x0.shape
|
|
||||||
attn = attn.reshape(b, -1, hw1, hw2)
|
|
||||||
# Global Average Pool
|
|
||||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
|
||||||
# Reshape
|
|
||||||
mask = (
|
|
||||||
mask.reshape(b, *mid_shape)
|
|
||||||
.unsqueeze(1)
|
|
||||||
.type(attn.dtype)
|
|
||||||
)
|
|
||||||
# Upsample
|
|
||||||
mask = F.interpolate(mask, (lh, lw))
|
|
||||||
|
|
||||||
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
|
||||||
blurred = blurred * mask + x0 * (1 - mask)
|
|
||||||
return blurred
|
|
||||||
|
|
||||||
def gaussian_blur_2d(img, kernel_size, sigma):
|
|
||||||
ksize_half = (kernel_size - 1) * 0.5
|
|
||||||
|
|
||||||
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
|
||||||
|
|
||||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
|
||||||
|
|
||||||
x_kernel = pdf / pdf.sum()
|
|
||||||
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
|
||||||
|
|
||||||
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
|
||||||
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
|
||||||
|
|
||||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
|
||||||
|
|
||||||
img = F.pad(img, padding, mode="reflect")
|
|
||||||
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
|
||||||
return img
|
|
||||||
|
|
||||||
class CFGNoisePredictor(torch.nn.Module):
|
class CFGNoisePredictor(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1,8 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
import os
|
import os
|
||||||
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
||||||
|
import comfy.samplers
|
||||||
|
|
||||||
# from comfy/ldm/modules/attention.py
|
# from comfy/ldm/modules/attention.py
|
||||||
# but modified to return attention scores as well as output
|
# but modified to return attention scores as well as output
|
||||||
@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
|
|||||||
)
|
)
|
||||||
return (out, sim)
|
return (out, sim)
|
||||||
|
|
||||||
class SagNode:
|
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||||
|
# reshape and GAP the attention map
|
||||||
|
_, hw1, hw2 = attn.shape
|
||||||
|
b, _, lh, lw = x0.shape
|
||||||
|
attn = attn.reshape(b, -1, hw1, hw2)
|
||||||
|
# Global Average Pool
|
||||||
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||||
|
ratio = round(math.sqrt(lh * lw / hw1))
|
||||||
|
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||||
|
|
||||||
|
# Reshape
|
||||||
|
mask = (
|
||||||
|
mask.reshape(b, *mid_shape)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.type(attn.dtype)
|
||||||
|
)
|
||||||
|
# Upsample
|
||||||
|
mask = F.interpolate(mask, (lh, lw))
|
||||||
|
|
||||||
|
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||||
|
blurred = blurred * mask + x0 * (1 - mask)
|
||||||
|
return blurred
|
||||||
|
|
||||||
|
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||||
|
ksize_half = (kernel_size - 1) * 0.5
|
||||||
|
|
||||||
|
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||||
|
|
||||||
|
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||||
|
|
||||||
|
x_kernel = pdf / pdf.sum()
|
||||||
|
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
|
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||||
|
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||||
|
|
||||||
|
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||||
|
|
||||||
|
img = F.pad(img, padding, mode="reflect")
|
||||||
|
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||||
|
return img
|
||||||
|
|
||||||
|
class SelfAttentionGuidance:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "model": ("MODEL",),
|
return {"required": { "model": ("MODEL",),
|
||||||
@ -63,15 +109,9 @@ class SagNode:
|
|||||||
|
|
||||||
def patch(self, model, scale, blur_sigma):
|
def patch(self, model, scale, blur_sigma):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
# set extra options on the model
|
|
||||||
m.model_options["sag"] = True
|
|
||||||
m.model_options["sag_scale"] = scale
|
|
||||||
m.model_options["sag_sigma"] = blur_sigma
|
|
||||||
|
|
||||||
attn_scores = None
|
attn_scores = None
|
||||||
mid_block_shape = None
|
mid_block_shape = None
|
||||||
m.model.get_attn_scores = lambda: attn_scores
|
|
||||||
m.model.get_mid_block_shape = lambda: mid_block_shape
|
|
||||||
|
|
||||||
# TODO: make this work properly with chunked batches
|
# TODO: make this work properly with chunked batches
|
||||||
# currently, we can only save the attn from one UNet call
|
# currently, we can only save the attn from one UNet call
|
||||||
@ -92,24 +132,41 @@ class SagNode:
|
|||||||
else:
|
else:
|
||||||
return optimized_attention(q, k, v, heads=heads)
|
return optimized_attention(q, k, v, heads=heads)
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal attn_scores
|
||||||
|
nonlocal mid_block_shape
|
||||||
|
uncond_attn = attn_scores
|
||||||
|
|
||||||
|
sag_scale = scale
|
||||||
|
sag_sigma = blur_sigma
|
||||||
|
sag_threshold = 1.0
|
||||||
|
model = args["model"]
|
||||||
|
uncond_pred = args["uncond_denoised"]
|
||||||
|
uncond = args["uncond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
model_options = args["model_options"]
|
||||||
|
x = args["input"]
|
||||||
|
|
||||||
|
# create the adversarially blurred image
|
||||||
|
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||||
|
degraded_noised = degraded + x - uncond_pred
|
||||||
|
# call into the UNet
|
||||||
|
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
||||||
|
return cfg_result + (degraded - sag) * sag_scale
|
||||||
|
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
# from diffusers:
|
# from diffusers:
|
||||||
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
||||||
def set_model_patch_replace(patch, name, key):
|
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
|
||||||
to = m.model_options["transformer_options"]
|
|
||||||
if "patches_replace" not in to:
|
|
||||||
to["patches_replace"] = {}
|
|
||||||
if name not in to["patches_replace"]:
|
|
||||||
to["patches_replace"][name] = {}
|
|
||||||
to["patches_replace"][name][key] = patch
|
|
||||||
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
|
|
||||||
# from diffusers:
|
|
||||||
# unet.mid_block.attentions[0].register_forward_hook()
|
|
||||||
def forward_hook(m, inp, out):
|
|
||||||
nonlocal mid_block_shape
|
|
||||||
mid_block_shape = out[0].shape[-2:]
|
|
||||||
m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
|
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"Self-Attention Guidance": SagNode,
|
"SelfAttentionGuidance": SelfAttentionGuidance,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"SelfAttentionGuidance": "Self-Attention Guidance",
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user