mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Simplify differential diffusion code.
This commit is contained in:
parent
727021bdea
commit
12c1080ebc
@ -67,6 +67,9 @@ class ModelPatcher:
|
|||||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||||
|
|
||||||
|
def set_model_denoise_mask_function(self, denoise_mask_function):
|
||||||
|
self.model_options["denoise_mask_function"] = denoise_mask_function
|
||||||
|
|
||||||
def set_model_patch(self, patch, name):
|
def set_model_patch(self, patch, name):
|
||||||
to = self.model_options["transformer_options"]
|
to = self.model_options["transformer_options"]
|
||||||
if "patches" not in to:
|
if "patches" not in to:
|
||||||
|
@ -272,13 +272,14 @@ class CFGNoisePredictor(torch.nn.Module):
|
|||||||
return self.apply_model(*args, **kwargs)
|
return self.apply_model(*args, **kwargs)
|
||||||
|
|
||||||
class KSamplerX0Inpaint(torch.nn.Module):
|
class KSamplerX0Inpaint(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model, sigmas):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
|
self.sigmas = sigmas
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
if "denoise_mask_function" in model_options:
|
if "denoise_mask_function" in model_options:
|
||||||
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask)
|
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
|
||||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||||
@ -528,7 +529,7 @@ class KSAMPLER(Sampler):
|
|||||||
|
|
||||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
extra_args["denoise_mask"] = denoise_mask
|
extra_args["denoise_mask"] = denoise_mask
|
||||||
model_k = KSamplerX0Inpaint(model_wrap)
|
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
|
||||||
model_k.latent_image = latent_image
|
model_k.latent_image = latent_image
|
||||||
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
|
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# code adapted from https://github.com/exx8/differential-diffusion
|
# code adapted from https://github.com/exx8/differential-diffusion
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
|
||||||
|
|
||||||
class DifferentialDiffusion():
|
class DifferentialDiffusion():
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -13,81 +12,27 @@ class DifferentialDiffusion():
|
|||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
INIT = False
|
INIT = False
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def IS_CHANGED(s, *args, **kwargs):
|
|
||||||
DifferentialDiffusion.INIT = s.INIT = True
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
DifferentialDiffusion.INIT = False
|
|
||||||
self.sigmas: torch.Tensor = None
|
|
||||||
self.thresholds: torch.Tensor = None
|
|
||||||
self.mask_i = None
|
|
||||||
self.valid_sigmas = False
|
|
||||||
self.varying_sigmas_samplers = ["dpmpp_2s", "dpmpp_sde", "dpm_2", "heun", "restart"]
|
|
||||||
|
|
||||||
def apply(self, model):
|
def apply(self, model):
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["denoise_mask_function"] = self.forward
|
model.set_model_denoise_mask_function(self.forward)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
def init_sigmas(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
|
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
|
||||||
self.__init__()
|
model = extra_options["model"]
|
||||||
self.sigmas, sampler = find_outer_instance("sigmas", callback=get_sigmas_and_sampler) or (None, "")
|
step_sigmas = extra_options["sigmas"]
|
||||||
self.valid_sigmas = not ("sample_" not in sampler or any(s in sampler for s in self.varying_sigmas_samplers)) or "generic" in sampler
|
sigma_to = model.inner_model.model_sampling.sigma_min
|
||||||
if self.sigmas is None:
|
if step_sigmas[-1] > sigma_to:
|
||||||
self.sigmas = sigma[:1].repeat(2)
|
sigma_to = step_sigmas[-1]
|
||||||
self.sigmas[-1].zero_()
|
sigma_from = step_sigmas[0]
|
||||||
self.sigmas_min = self.sigmas.min()
|
|
||||||
self.sigmas_max = self.sigmas.max()
|
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
|
||||||
self.thresholds = torch.linspace(1, 0, self.sigmas.shape[0], dtype=sigma.dtype, device=sigma.device)
|
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
|
||||||
self.thresholds_min_len = self.thresholds.shape[0] - 1
|
current_ts = model.inner_model.model_sampling.timestep(sigma)
|
||||||
if self.valid_sigmas:
|
|
||||||
thresholds = self.thresholds[:-1].reshape(-1, 1, 1, 1, 1)
|
threshold = (current_ts - ts_to) / (ts_from - ts_to)
|
||||||
mask = denoise_mask.unsqueeze(0)
|
|
||||||
mask = (mask >= thresholds).to(denoise_mask.dtype)
|
|
||||||
self.mask_i = iter(mask)
|
|
||||||
|
|
||||||
def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor):
|
|
||||||
if self.sigmas is None or DifferentialDiffusion.INIT:
|
|
||||||
self.init_sigmas(sigma, denoise_mask)
|
|
||||||
if self.valid_sigmas:
|
|
||||||
try:
|
|
||||||
return next(self.mask_i)
|
|
||||||
except StopIteration:
|
|
||||||
self.valid_sigmas = False
|
|
||||||
if self.thresholds_min_len > 1:
|
|
||||||
nearest_idx = (self.sigmas - sigma[0]).abs().argmin()
|
|
||||||
if not self.thresholds_min_len > nearest_idx:
|
|
||||||
nearest_idx = -2
|
|
||||||
threshold = self.thresholds[nearest_idx]
|
|
||||||
else:
|
|
||||||
threshold = (sigma[0] - self.sigmas_min) / (self.sigmas_max - self.sigmas_min)
|
|
||||||
return (denoise_mask >= threshold).to(denoise_mask.dtype)
|
return (denoise_mask >= threshold).to(denoise_mask.dtype)
|
||||||
|
|
||||||
def get_sigmas_and_sampler(frame, target):
|
|
||||||
found = frame.f_locals[target]
|
|
||||||
if isinstance(found, torch.Tensor) and found[-1] < 0.1:
|
|
||||||
return found, frame.f_code.co_name
|
|
||||||
return False
|
|
||||||
|
|
||||||
def find_outer_instance(target: str, target_type=None, callback=None):
|
|
||||||
frame = inspect.currentframe()
|
|
||||||
i = 0
|
|
||||||
while frame and i < 100:
|
|
||||||
if target in frame.f_locals:
|
|
||||||
if callback is not None:
|
|
||||||
res = callback(frame, target)
|
|
||||||
if res:
|
|
||||||
return res
|
|
||||||
else:
|
|
||||||
found = frame.f_locals[target]
|
|
||||||
if isinstance(found, target_type):
|
|
||||||
return found
|
|
||||||
frame = frame.f_back
|
|
||||||
i += 1
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"DifferentialDiffusion": DifferentialDiffusion,
|
"DifferentialDiffusion": DifferentialDiffusion,
|
||||||
|
Loading…
Reference in New Issue
Block a user