Refactor of sampler code to deal more easily with different model types.

This commit is contained in:
comfyanonymous 2023-07-17 01:22:12 -04:00
parent ac9c038ac2
commit 3ded1a3a04
8 changed files with 68 additions and 53 deletions

View File

@ -180,7 +180,6 @@ class NoiseScheduleVP:
def model_wrapper(
model,
sampling_function,
noise_schedule,
model_type="noise",
model_kwargs={},
@ -295,7 +294,7 @@ def model_wrapper(
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
output = sampling_function(model, x, t_input, **model_kwargs)
output = model(x, t_input, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
@ -843,10 +842,12 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
else:
timesteps = sigmas.clone()
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_t(timesteps[s]) / 1000) + (1 / len(model.sigmas))
alphas_cumprod = model.inner_model.alphas_cumprod
ns = NoiseScheduleVP('discrete', alphas_cumprod=model.inner_model.alphas_cumprod)
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
if image is not None:
img = image * ns.marginal_alpha(timesteps[0])
@ -859,18 +860,15 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
img = noise
if to_zero:
timesteps[-1] = (1 / len(model.sigmas))
timesteps[-1] = (1 / len(alphas_cumprod))
device = noise.device
if model.parameterization == "v":
model_type = "v"
else:
model_type = "noise"
model_fn = model_wrapper(
model.inner_model.inner_model.apply_model,
sampling_function,
model.predict_eps_discrete_timestep,
ns,
model_type=model_type,
guidance_type="uncond",

View File

@ -63,12 +63,17 @@ class DiscreteSchedule(nn.Module):
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
def sigma_to_discrete_timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
if quantize:
return dists.abs().argmin(dim=0).view(sigma.shape)
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
if quantize:
return self.sigma_to_discrete_timestep(sigma)
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
@ -85,6 +90,10 @@ class DiscreteSchedule(nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def predict_eps_discrete_timestep(self, input, t, **kwargs):
sigma = self.t_to_sigma(t.round())
input = input * ((sigma ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / sigma
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted

View File

@ -14,6 +14,7 @@ class DDIMSampler(object):
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.parameterization = kwargs.get("parameterization", "eps")
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
@ -261,7 +262,7 @@ class DDIMSampler(object):
b, *_, device = *x.shape, x.device
if denoise_function is not None:
model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
model_output = denoise_function(x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
@ -289,13 +290,13 @@ class DDIMSampler(object):
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
if self.parameterization == "v":
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
assert self.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
@ -309,7 +310,7 @@ class DDIMSampler(object):
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
if self.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output

View File

@ -4,10 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import numpy as np
from enum import Enum
from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
class BaseModel(torch.nn.Module):
def __init__(self, model_config, v_prediction=False):
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__()
unet_config = model_config.unet_config
@ -15,16 +20,11 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.diffusion_model = UNetModel(**unet_config)
self.v_prediction = v_prediction
if self.v_prediction:
self.parameterization = "v"
else:
self.parameterization = "eps"
self.model_type = model_type
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
print("v_prediction", v_prediction)
print("model_type", model_type.name)
print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
@ -103,8 +103,8 @@ class BaseModel(torch.nn.Module):
class SD21UNCLIP(BaseModel):
def __init__(self, model_config, noise_aug_config, v_prediction=True):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
super().__init__(model_config, model_type)
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
def encode_adm(self, **kwargs):
@ -139,13 +139,13 @@ class SD21UNCLIP(BaseModel):
return adm_out
class SDInpaint(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__(model_config, model_type)
self.concat_keys = ("mask", "masked_image")
class SDXLRefiner(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__(model_config, model_type)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
@ -171,8 +171,8 @@ class SDXLRefiner(BaseModel):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SDXL(BaseModel):
def __init__(self, model_config, v_prediction=False):
super().__init__(model_config, v_prediction)
def __init__(self, model_config, model_type=ModelType.EPS):
super().__init__(model_config, model_type)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):

View File

@ -6,6 +6,7 @@ from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from comfy import model_base
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
@ -488,11 +489,11 @@ class KSampler:
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v":
if self.model.model_type == model_base.ModelType.V_PREDICTION:
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
@ -614,7 +615,7 @@ class KSampler:
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
@ -638,7 +639,7 @@ class KSampler:
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
denoise_function=self.model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,

View File

@ -1008,11 +1008,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "noise_aug_config" in model_config_params:
noise_aug_config = model_config_params["noise_aug_config"]
v_prediction = False
model_type = model_base.ModelType.EPS
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":
v_prediction = True
model_type = model_base.ModelType.V_PREDICTION
clip = None
vae = None
@ -1032,11 +1032,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
if config['model']["target"].endswith("LatentInpaintDiffusion"):
model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
model = model_base.SDInpaint(model_config, model_type=model_type)
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
else:
model = model_base.BaseModel(model_config, v_prediction=v_prediction)
model = model_base.BaseModel(model_config, model_type=model_type)
if fp16:
model = model.half()

View File

@ -53,13 +53,13 @@ class SD20(supported_models_base.BASE):
latent_format = latent_formats.SD15
def v_prediction(self, state_dict, prefix=""):
def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return True
return False
return model_base.ModelType.V_PREDICTION
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
@ -145,8 +145,14 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS
def get_model(self, state_dict, prefix=""):
return model_base.SDXL(self)
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}

View File

@ -41,8 +41,8 @@ class BASE:
return False
return True
def v_prediction(self, state_dict, prefix=""):
return False
def model_type(self, state_dict, prefix=""):
return model_base.ModelType.EPS
def inpaint_model(self):
return self.unet_config["in_channels"] > 4
@ -55,11 +55,11 @@ class BASE:
def get_model(self, state_dict, prefix=""):
if self.inpaint_model():
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
elif self.noise_aug_config is not None:
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
else:
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))
return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
def process_clip_state_dict(self, state_dict):
return state_dict