diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 55ca913e..e0acdc96 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -61,6 +61,9 @@ class ModelPatcher: else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + def set_model_sampler_post_cfg_function(self, post_cfg_function): + self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function @@ -70,13 +73,17 @@ class ModelPatcher: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] - def set_model_patch_replace(self, patch, name, block_name, number): + def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): to = self.model_options["transformer_options"] if "patches_replace" not in to: to["patches_replace"] = {} if name not in to["patches_replace"]: to["patches_replace"][name] = {} - to["patches_replace"][name][(block_name, number)] = patch + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") @@ -84,11 +91,11 @@ class ModelPatcher: def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") - def set_model_attn1_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn1", block_name, number) + def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) - def set_model_attn2_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn2", block_name, number) + def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) def set_model_attn1_output_patch(self, patch): self.set_model_patch(patch, "attn1_output_patch") diff --git a/comfy/samplers.py b/comfy/samplers.py index 1cdad736..106e7287 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,7 +1,6 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch -import torch.nn.functional as F import enum from comfy import model_management import math @@ -9,310 +8,260 @@ from comfy import model_base import comfy.utils import comfy.conds +def get_area_and_mult(conds, x_in, timestep_in): + area = (x_in.shape[2], x_in.shape[3], 0, 0) + strength = 1.0 + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] + if timestep_in[0] > timestep_start: + return None + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] + if timestep_in[0] < timestep_end: + return None + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] + + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if 'mask' in conds: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in conds: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + + conditioning = {} + model_conds = conds["model_conds"] + for c in model_conds: + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + + control = None + if 'control' in conds: + control = conds['control'] + + patches = None + if 'gligen' in conds: + gligen = conds['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + return (input_x, mult, conditioning, area, control, patches) + +def cond_equal_size(c1, c2): + if c1 is c2: + return True + if c1.keys() != c2.keys(): + return False + for k in c1: + if not c1[k].can_concat(c2[k]): + return False + return True + +def can_concat_cond(c1, c2): + if c1[0].shape != c2[0].shape: + return False + + #control + if (c1[4] is None) != (c2[4] is None): + return False + if c1[4] is not None: + if c1[4] is not c2[4]: + return False + + #patches + if (c1[5] is None) != (c2[5] is None): + return False + if (c1[5] is not None): + if c1[5] is not c2[5]: + return False + + return cond_equal_size(c1[2], c2[2]) + +def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + c_adm = [] + crossattn_max_len = 0 + + temp = {} + for x in c_list: + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + + out = {} + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + + return out + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): + out_cond = torch.zeros_like(x_in) + out_count = torch.ones_like(x_in) * 1e-37 + + out_uncond = torch.zeros_like(x_in) + out_uncond_count = torch.ones_like(x_in) * 1e-37 + + COND = 0 + UNCOND = 1 + + to_run = [] + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, COND)] + if uncond is not None: + for x in uncond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, UNCOND)] + + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = model_management.get_free_memory(x_in.device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) < free_memory: + to_batch = batch_amount + break + + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x += [p[0]] + mult += [p[1]] + c += [p[2]] + area += [p[3]] + cond_or_uncond += [o[1]] + control = p[4] + patches = p[5] + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) + + transformer_options = {} + if 'transformer_options' in model_options: + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + + c['transformer_options'] = transformer_options + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + del input_x + + for o in range(batch_chunks): + if cond_or_uncond[o] == COND: + out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + else: + out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + del mult + + out_cond /= out_count + del out_count + out_uncond /= out_uncond_count + del out_uncond_count + return out_cond, out_uncond #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(conds, x_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 + if math.isclose(cond_scale, 1.0): + uncond_ = None + else: + uncond_ = uncond - if 'timestep_start' in conds: - timestep_start = conds['timestep_start'] - if timestep_in[0] > timestep_start: - return None - if 'timestep_end' in conds: - timestep_end = conds['timestep_end'] - if timestep_in[0] < timestep_end: - return None - if 'area' in conds: - area = conds['area'] - if 'strength' in conds: - strength = conds['strength'] - - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in conds: - # Scale the mask to the size of the input - # The mask should have been resized as we began the sampling process - mask_strength = 1.0 - if "mask_strength" in conds: - mask_strength = conds["mask_strength"] - mask = conds['mask'] - assert(mask.shape[1] == x_in.shape[2]) - assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) - else: - mask = torch.ones_like(input_x) - mult = mask * strength - - if 'mask' not in conds: - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - - conditioning = {} - model_conds = conds["model_conds"] - for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - - control = None - if 'control' in conds: - control = conds['control'] - - patches = None - if 'gligen' in conds: - gligen = conds['gligen'] - patches = {} - gligen_type = gligen[0] - gligen_model = gligen[1] - if gligen_type == "position": - gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) - else: - gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) - - patches['middle_patch'] = [gligen_patch] - - return (input_x, mult, conditioning, area, control, patches) - - def cond_equal_size(c1, c2): - if c1 is c2: - return True - if c1.keys() != c2.keys(): - return False - for k in c1: - if not c1[k].can_concat(c2[k]): - return False - return True - - def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: - return False - - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: - return False - - #patches - if (c1[5] is None) != (c2[5] is None): - return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - - return cond_equal_size(c1[2], c2[2]) - - def cond_cat(c_list): - c_crossattn = [] - c_concat = [] - c_adm = [] - crossattn_max_len = 0 - - temp = {} - for x in c_list: - for k in x: - cur = temp.get(k, []) - cur.append(x[k]) - temp[k] = cur - - out = {} - for k in temp: - conds = temp[k] - out[k] = conds[0].concat(conds[1:]) - - return out - - def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - - to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, UNCOND)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) < free_memory: - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - else: - transformer_options["patches"] = patches - - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep - - c['transformer_options'] = transformer_options - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x - - for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult - - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond - - - # if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out. - if math.isclose(cond_scale, 1.0) and "sag" not in model_options: - uncond = None - - cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} cfg_result = x - model_options["sampler_cfg_function"](args) - if "sag" in model_options: - assert uncond is not None, "SAG requires uncond guidance" - sag_scale = model_options["sag_scale"] - sag_sigma = model_options["sag_sigma"] - sag_threshold = model_options.get("sag_threshold", 1.0) + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) - # these methods are added by the sag patcher - uncond_attn = model.get_attn_scores() - mid_shape = model.get_mid_block_shape() - # create the adversarially blurred image - degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold) - degraded_noised = degraded + x - uncond_pred - # call into the UNet - (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options) - cfg_result += (degraded - sag) * sag_scale return cfg_result -def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0): - # reshape and GAP the attention map - _, hw1, hw2 = attn.shape - b, _, lh, lw = x0.shape - attn = attn.reshape(b, -1, hw1, hw2) - # Global Average Pool - mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - # Reshape - mask = ( - mask.reshape(b, *mid_shape) - .unsqueeze(1) - .type(attn.dtype) - ) - # Upsample - mask = F.interpolate(mask, (lh, lw)) - - blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) - blurred = blurred * mask + x0 * (1 - mask) - return blurred - -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = F.pad(img, padding, mode="reflect") - img = F.conv2d(img, kernel2d, groups=img.shape[-3]) - return img - class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): super().__init__() diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 1ec0c93a..4c609565 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -1,8 +1,12 @@ import torch from torch import einsum +import torch.nn.functional as F +import math + from einops import rearrange, repeat import os from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +import comfy.samplers # from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output @@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): ) return (out, sim) -class SagNode: +def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + ratio = round(math.sqrt(lh * lw / hw1)) + mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] + + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img + +class SelfAttentionGuidance: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -63,15 +109,9 @@ class SagNode: def patch(self, model, scale, blur_sigma): m = model.clone() - # set extra options on the model - m.model_options["sag"] = True - m.model_options["sag_scale"] = scale - m.model_options["sag_sigma"] = blur_sigma - + attn_scores = None mid_block_shape = None - m.model.get_attn_scores = lambda: attn_scores - m.model.get_mid_block_shape = lambda: mid_block_shape # TODO: make this work properly with chunked batches # currently, we can only save the attn from one UNet call @@ -92,24 +132,41 @@ class SagNode: else: return optimized_attention(q, k, v, heads=heads) + def post_cfg_function(args): + nonlocal attn_scores + nonlocal mid_block_shape + uncond_attn = attn_scores + + sag_scale = scale + sag_sigma = blur_sigma + sag_threshold = 1.0 + model = args["model"] + uncond_pred = args["uncond_denoised"] + uncond = args["uncond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"] + x = args["input"] + + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + return cfg_result + (degraded - sag) * sag_scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + # from diffusers: # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch - def set_model_patch_replace(patch, name, key): - to = m.model_options["transformer_options"] - if "patches_replace" not in to: - to["patches_replace"] = {} - if name not in to["patches_replace"]: - to["patches_replace"][name] = {} - to["patches_replace"][name][key] = patch - set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0)) - # from diffusers: - # unet.mid_block.attentions[0].register_forward_hook() - def forward_hook(m, inp, out): - nonlocal mid_block_shape - mid_block_shape = out[0].shape[-2:] - m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook) + m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) + return (m, ) NODE_CLASS_MAPPINGS = { - "Self-Attention Guidance": SagNode, + "SelfAttentionGuidance": SelfAttentionGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SelfAttentionGuidance": "Self-Attention Guidance", }