Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-01-11 02:15:17 +00:00)
Refactor of sampler code to deal more easily with different model types.
commit 3ded1a3a04
parent ac9c038ac2
@@ -180,7 +180,6 @@ class NoiseScheduleVP:

 def model_wrapper(
     model,
-    sampling_function,
     noise_schedule,
     model_type="noise",
     model_kwargs={},
@@ -295,7 +294,7 @@ def model_wrapper(
         if t_continuous.reshape((-1,)).shape[0] == 1:
             t_continuous = t_continuous.expand((x.shape[0]))
         t_input = get_model_input_time(t_continuous)
-        output = sampling_function(model, x, t_input, **model_kwargs)
+        output = model(x, t_input, **model_kwargs)
         if model_type == "noise":
             return output
         elif model_type == "x_start":
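Note: after this change model_wrapper no longer threads an external sampling_function through to the UNet; the callable handed in as `model` is expected to perform the full denoising call itself. A minimal runnable sketch of the new contract (standalone; every name except model_wrapper's own parameters is a hypothetical stand-in):

    import torch

    def model_wrapper(model, model_type="noise", model_kwargs={}):
        # The wrapped callable is invoked directly: output = model(x, t_input, ...)
        def noise_pred_fn(x, t_continuous):
            output = model(x, t_continuous, **model_kwargs)
            if model_type == "noise":
                return output
            raise NotImplementedError(model_type)
        return noise_pred_fn

    # Usage with a dummy eps-prediction model:
    dummy_eps_model = lambda x, t: torch.zeros_like(x)
    fn = model_wrapper(dummy_eps_model)
    print(fn(torch.randn(1, 4, 8, 8), torch.tensor([0.5])).shape)  # torch.Size([1, 4, 8, 8])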
@@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
        else:
            timesteps = sigmas.clone()

-       for s in range(timesteps.shape[0]):
-           timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
+       alphas_cumprod = model.inner_model.alphas_cumprod

-       ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
+       for s in range(timesteps.shape[0]):
+           timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
+
+       ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        if image is not None:
            img = image * ns.marginal_alpha(timesteps[0])
@@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
            img = noise

        if to_zero:
-           timesteps[-1] = (1 / len(model.sigmas))
+           timesteps[-1] = (1 / len(alphas_cumprod))

        device = noise.device

-       if model.parameterization == "v":
-           model_type = "v"
-       else:
-           model_type = "noise"
+       model_type = "noise"

        model_fn = model_wrapper(
-           model.inner_model.inner_model.apply_model,
-           sampling_function,
+           model.predict_eps_discrete_timestep,
            ns,
            model_type=model_type,
            guidance_type="uncond",
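Note: sample_unipc now converts each sigma to UniPC's continuous time in (0, 1] via the denoiser's new sigma_to_discrete_timestep rather than the interpolating sigma_to_t. A self-contained sketch of that mapping under a dummy 1000-step schedule (the schedule values are fabricated for illustration; only the formula matches the diff):

    import torch

    alphas_cumprod = torch.linspace(0.999, 0.01, 1000)  # stand-in schedule
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    log_sigmas = sigmas.log()

    def sigma_to_discrete_timestep(sigma):
        # nearest discrete timestep in log-sigma space, as in DiscreteSchedule
        dists = sigma.log() - log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    sigma = torch.tensor([2.5])
    t_continuous = sigma_to_discrete_timestep(sigma) / 1000 + 1 / len(alphas_cumprod)
    print(t_continuous)  # a value in (0, 1], the range NoiseScheduleVP expects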
@@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
         t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
         return sampling.append_zero(self.t_to_sigma(t))

-    def sigma_to_t(self, sigma, quantize=None):
-        quantize = self.quantize if quantize is None else quantize
+    def sigma_to_discrete_timestep(self, sigma):
         log_sigma = sigma.log()
         dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def sigma_to_t(self, sigma, quantize=None):
+        quantize = self.quantize if quantize is None else quantize
         if quantize:
-            return dists.abs().argmin(dim=0).view(sigma.shape)
+            return self.sigma_to_discrete_timestep(sigma)
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
         low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
         high_idx = low_idx + 1
         low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
@@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
         log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
         return log_sigma.exp()

+    def predict_eps_discrete_timestep(self, input, t, **kwargs):
+        sigma = self.t_to_sigma(t.round())
+        input = input * ((sigma ** 2 + 1.0) ** 0.5)
+        return (input - self(input, sigma, **kwargs)) / sigma
+
 class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
     """A wrapper for discrete schedule DDPM models that output eps (the predicted
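Note: the new predict_eps_discrete_timestep helper gives the DDIM and UniPC paths an eps prediction from a k-diffusion denoiser: it rounds t to the nearest discrete step, undoes the CompVis-style 1/sqrt(sigma^2 + 1) input scaling, and recovers eps as (noisy - denoised) / sigma. A runnable sketch with a dummy denoiser (t_to_sigma and the denoiser itself are stand-ins, not the real schedule):

    import torch

    def t_to_sigma(t):
        return 0.1 + t  # stand-in; the real schedule interpolates log-sigmas

    class DummyDenoiser:
        def __call__(self, x, sigma, **kwargs):
            return torch.zeros_like(x)  # pretend the denoised image is all zeros
        def predict_eps_discrete_timestep(self, input, t, **kwargs):
            sigma = t_to_sigma(t.round())
            # undo the input scaling used by the CompVis wrappers
            input = input * ((sigma ** 2 + 1.0) ** 0.5)
            # eps = (noisy - denoised) / sigma
            return (input - self(input, sigma, **kwargs)) / sigma

    d = DummyDenoiser()
    x = torch.randn(1, 4, 8, 8)
    print(d.predict_eps_discrete_timestep(x, torch.tensor(999.0)).shape)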
@@ -14,6 +14,7 @@ class DDIMSampler(object):
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
         self.device = device
+        self.parameterization = kwargs.get("parameterization", "eps")

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
@@ -261,7 +262,7 @@ class DDIMSampler(object):
         b, *_, device = *x.shape, x.device

         if denoise_function is not None:
-            model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
+            model_output = denoise_function(x, t, **extra_args)
         elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
             model_output = self.model.apply_model(x, t, c)
         else:
@@ -289,13 +290,13 @@ class DDIMSampler(object):
             model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

-        if self.model.parameterization == "v":
+        if self.parameterization == "v":
             e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
         else:
             e_t = model_output

         if score_corrector is not None:
-            assert self.model.parameterization == "eps", 'not implemented'
+            assert self.parameterization == "eps", 'not implemented'
             e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
@@ -309,7 +310,7 @@ class DDIMSampler(object):
         sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

         # current prediction for x_0
-        if self.model.parameterization != "v":
+        if self.parameterization != "v":
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         else:
             pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
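Note: DDIMSampler now stores its own parameterization, passed as a kwarg and defaulting to "eps", instead of reaching into self.model.parameterization, which BaseModel no longer defines. Sketch of the constructor change in isolation (class reduced to the relevant lines):

    class DDIMSamplerSketch:
        def __init__(self, model, schedule="linear", **kwargs):
            self.model = model
            self.schedule = schedule
            self.parameterization = kwargs.get("parameterization", "eps")

    s = DDIMSamplerSketch(model=None, parameterization="v")
    print(s.parameterization)  # "v"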
@@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from enum import Enum
 from . import utils

+class ModelType(Enum):
+    EPS = 1
+    V_PREDICTION = 2
+
 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, v_prediction=False):
+    def __init__(self, model_config, model_type=ModelType.EPS):
         super().__init__()

         unet_config = model_config.unet_config
@@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
-        self.v_prediction = v_prediction
-        if self.v_prediction:
-            self.parameterization = "v"
-        else:
-            self.parameterization = "eps"
-
+        self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
-        print("v_prediction", v_prediction)
+        print("model_type", model_type.name)
         print("adm", self.adm_channels)

     def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
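Note: the v_prediction boolean is replaced by a ModelType enum, so further prediction types can be added later without stacking flags. The enum is small enough to reproduce verbatim, with the new debug print as usage:

    from enum import Enum

    class ModelType(Enum):
        EPS = 1
        V_PREDICTION = 2

    # BaseModel.__init__ now stores the enum and logs its name:
    model_type = ModelType.V_PREDICTION
    print("model_type", model_type.name)  # model_type V_PREDICTION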
@@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):


 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, v_prediction=True):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
+        super().__init__(model_config, model_type)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):
@@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
@@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, model_config, v_prediction=False):
-        super().__init__(model_config, v_prediction)
+    def __init__(self, model_config, model_type=ModelType.EPS):
+        super().__init__(model_config, model_type)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
@@ -6,6 +6,7 @@ from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 import math
+from comfy import model_base

 def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
     return abs(a*b) // math.gcd(a, b)
@@ -488,11 +489,11 @@ class KSampler:
     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
         self.model_denoise = CFGNoisePredictor(self.model)
-        if self.model.parameterization == "v":
+        if self.model.model_type == model_base.ModelType.V_PREDICTION:
             self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
         else:
             self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
-        self.model_wrap.parameterization = self.model.parameterization
+
         self.model_k = KSamplerX0Inpaint(self.model_wrap)
         self.device = device
         if scheduler not in self.SCHEDULERS:
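Note: KSampler.__init__ now branches on the enum rather than a parameterization string, and no longer copies parameterization onto the wrapper. A standalone sketch of the dispatch (the two denoiser classes are reduced to dummies):

    from enum import Enum

    class ModelType(Enum):
        EPS = 1
        V_PREDICTION = 2

    class CompVisDenoiserStub: pass
    class CompVisVDenoiserStub: pass

    def wrap_model(model_type):
        if model_type == ModelType.V_PREDICTION:
            return CompVisVDenoiserStub()
        return CompVisDenoiserStub()

    print(type(wrap_model(ModelType.EPS)).__name__)  # CompVisDenoiserStub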
@@ -614,7 +615,7 @@ class KSampler:
         elif self.sampler == "ddim":
             timesteps = []
             for s in range(sigmas.shape[0]):
-                timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
+                timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
             noise_mask = None
             if denoise_mask is not None:
                 noise_mask = 1.0 - denoise_mask
@@ -638,7 +639,7 @@ class KSampler:
                x_T=z_enc,
                x0=latent_image,
                img_callback=ddim_callback,
-               denoise_function=sampling_function,
+               denoise_function=self.model_wrap.predict_eps_discrete_timestep,
                extra_args=extra_args,
                mask=noise_mask,
                to_zero=sigmas[-1]==0,
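Note: the ddim path now hands the sampler the wrapper's bound predict_eps_discrete_timestep as denoise_function, which is why DDIMSampler (above) calls denoise_function(x, t, **extra_args) without an apply_model argument. Minimal sketch of that call shape with a dummy wrapper:

    import torch

    class WrapSketch:
        def predict_eps_discrete_timestep(self, x, t, **kwargs):
            return torch.zeros_like(x)  # dummy eps prediction

    denoise_function = WrapSketch().predict_eps_discrete_timestep
    print(denoise_function(torch.randn(2, 4, 8, 8), torch.tensor([999.0, 999.0])).shape)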
comfy/sd.py (10 changes)
@@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
    if "noise_aug_config" in model_config_params:
        noise_aug_config = model_config_params["noise_aug_config"]

-   v_prediction = False
+   model_type = model_base.ModelType.EPS

    if "parameterization" in model_config_params:
        if model_config_params["parameterization"] == "v":
-           v_prediction = True
+           model_type = model_base.ModelType.V_PREDICTION

    clip = None
    vae = None
@@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)

    if config['model']["target"].endswith("LatentInpaintDiffusion"):
-       model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
+       model = model_base.SDInpaint(model_config, model_type=model_type)
    elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-       model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
+       model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
    else:
-       model = model_base.BaseModel(model_config, v_prediction=v_prediction)
+       model = model_base.BaseModel(model_config, model_type=model_type)

    if fp16:
        model = model.half()
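Note: load_checkpoint translates the YAML config's parameterization field into the new enum once, then forwards model_type= to whichever model class is built. Sketch of the mapping (a dummy enum stands in for model_base.ModelType):

    from enum import Enum

    class ModelType(Enum):
        EPS = 1
        V_PREDICTION = 2

    def model_type_from_config(model_config_params):
        model_type = ModelType.EPS
        if model_config_params.get("parameterization") == "v":
            model_type = ModelType.V_PREDICTION
        return model_type

    print(model_type_from_config({"parameterization": "v"}))  # ModelType.V_PREDICTION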
@@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):

    latent_format = latent_formats.SD15

-   def v_prediction(self, state_dict, prefix=""):
+   def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict[k]
            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
-               return True
-       return False
+               return model_base.ModelType.V_PREDICTION
+       return model_base.ModelType.EPS

    def process_clip_state_dict(self, state_dict):
        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
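Note: for SD2.0 checkpoints the prediction type is not stored in the file, so it is guessed from the standard deviation of one output-block bias; the inline comment in the diff itself flags the 0.09 threshold as a best-effort heuristic. Runnable sketch with a fabricated state dict:

    import torch
    from enum import Enum

    class ModelType(Enum):
        EPS = 1
        V_PREDICTION = 2

    def detect(state_dict, prefix=""):
        k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
        if torch.std(state_dict[k], unbiased=False) > 0.09:
            return ModelType.V_PREDICTION
        return ModelType.EPS

    fake_sd = {"output_blocks.11.1.transformer_blocks.0.norm1.bias": torch.randn(1024)}
    print(detect(fake_sd))  # std of randn is ~1, so ModelType.V_PREDICTION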
@@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):

    latent_format = latent_formats.SDXL

+   def model_type(self, state_dict, prefix=""):
+       if "v_pred" in state_dict:
+           return model_base.ModelType.V_PREDICTION
+       else:
+           return model_base.ModelType.EPS
+
    def get_model(self, state_dict, prefix=""):
-       return model_base.SDXL(self)
+       return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
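Note: SDXL checkpoints, by contrast, can carry an explicit "v_pred" key, so no statistical guess is needed:

    def sdxl_is_v_pred(state_dict):
        # mirrors the new SDXL.model_type check: presence of the key is the signal
        return "v_pred" in state_dict

    print(sdxl_is_v_pred({"v_pred": True}), sdxl_is_v_pred({}))  # True False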
@@ -41,8 +41,8 @@ class BASE:
            return False
        return True

-   def v_prediction(self, state_dict, prefix=""):
-       return False
+   def model_type(self, state_dict, prefix=""):
+       return model_base.ModelType.EPS

    def inpaint_model(self):
        return self.unet_config["in_channels"] > 4
@@ -55,11 +55,11 @@ class BASE:

    def get_model(self, state_dict, prefix=""):
        if self.inpaint_model():
-           return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
+           return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
        elif self.noise_aug_config is not None:
-           return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
+           return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
        else:
-           return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
+           return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))

    def process_clip_state_dict(self, state_dict):
        return state_dict
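Note: with the default model_type hook in place, every branch of BASE.get_model forwards the detected type, so subclasses only override model_type when they can do better than the EPS default. Compact sketch of the dispatch (dummy class; returns labels instead of real models):

    class BaseSketch:
        noise_aug_config = None
        def inpaint_model(self):
            return False
        def model_type(self, state_dict, prefix=""):
            return "EPS"  # default, as in supported_models_base.BASE
        def get_model(self, state_dict, prefix=""):
            mt = self.model_type(state_dict, prefix)
            if self.inpaint_model():
                return ("SDInpaint", mt)
            elif self.noise_aug_config is not None:
                return ("SD21UNCLIP", mt)
            return ("BaseModel", mt)

    print(BaseSketch().get_model({}))  # ('BaseModel', 'EPS')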