mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-14 13:17:32 +00:00
Sampling code refactor.
This commit is contained in:
parent
aeba1cc2a0
commit
446caf711c
@ -59,7 +59,7 @@ class DDIMSampler(object):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_custom(self,
|
def sample_custom(self,
|
||||||
ddim_timesteps,
|
ddim_timesteps,
|
||||||
conditioning,
|
conditioning=None,
|
||||||
callback=None,
|
callback=None,
|
||||||
img_callback=None,
|
img_callback=None,
|
||||||
quantize_x0=False,
|
quantize_x0=False,
|
||||||
|
@ -544,11 +544,152 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
|||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler:
|
||||||
|
def sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def max_denoise(self, model_wrap, sigmas):
|
||||||
|
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]))
|
||||||
|
|
||||||
|
class DDIM(Sampler):
|
||||||
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
|
timesteps = []
|
||||||
|
for s in range(sigmas.shape[0]):
|
||||||
|
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
||||||
|
noise_mask = None
|
||||||
|
if denoise_mask is not None:
|
||||||
|
noise_mask = 1.0 - denoise_mask
|
||||||
|
|
||||||
|
ddim_callback = None
|
||||||
|
if callback is not None:
|
||||||
|
total_steps = len(timesteps) - 1
|
||||||
|
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||||
|
|
||||||
|
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||||
|
|
||||||
|
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
|
||||||
|
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||||
|
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
|
||||||
|
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
|
||||||
|
batch_size=noise.shape[0],
|
||||||
|
shape=noise.shape[1:],
|
||||||
|
verbose=False,
|
||||||
|
eta=0.0,
|
||||||
|
x_T=z_enc,
|
||||||
|
x0=latent_image,
|
||||||
|
img_callback=ddim_callback,
|
||||||
|
denoise_function=model_wrap.predict_eps_discrete_timestep,
|
||||||
|
extra_args=extra_args,
|
||||||
|
mask=noise_mask,
|
||||||
|
to_zero=sigmas[-1]==0,
|
||||||
|
end_step=sigmas.shape[0] - 1,
|
||||||
|
disable_pbar=disable_pbar)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
class UNIPC(Sampler):
|
||||||
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
|
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||||
|
|
||||||
|
class UNIPCBH2(Sampler):
|
||||||
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
|
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||||
|
|
||||||
|
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||||
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
|
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
|
||||||
|
|
||||||
|
def ksampler(sampler_name):
|
||||||
|
class KSAMPLER(Sampler):
|
||||||
|
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||||
|
extra_args["denoise_mask"] = denoise_mask
|
||||||
|
model_k = KSamplerX0Inpaint(model_wrap)
|
||||||
|
model_k.latent_image = latent_image
|
||||||
|
model_k.noise = noise
|
||||||
|
|
||||||
|
if self.max_denoise(model_wrap, sigmas):
|
||||||
|
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||||
|
else:
|
||||||
|
noise = noise * sigmas[0]
|
||||||
|
|
||||||
|
k_callback = None
|
||||||
|
total_steps = len(sigmas) - 1
|
||||||
|
if callback is not None:
|
||||||
|
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||||
|
|
||||||
|
sigma_min = sigmas[-1]
|
||||||
|
if sigma_min == 0:
|
||||||
|
sigma_min = sigmas[-2]
|
||||||
|
|
||||||
|
if latent_image is not None:
|
||||||
|
noise += latent_image
|
||||||
|
if sampler_name == "dpm_fast":
|
||||||
|
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||||
|
elif sampler_name == "dpm_adaptive":
|
||||||
|
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||||
|
else:
|
||||||
|
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||||
|
return samples
|
||||||
|
return KSAMPLER
|
||||||
|
|
||||||
|
|
||||||
|
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
|
positive = positive[:]
|
||||||
|
negative = negative[:]
|
||||||
|
|
||||||
|
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
|
||||||
|
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
|
||||||
|
|
||||||
|
model_denoise = CFGNoisePredictor(model)
|
||||||
|
if model.model_type == model_base.ModelType.V_PREDICTION:
|
||||||
|
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
|
||||||
|
else:
|
||||||
|
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
|
||||||
|
|
||||||
|
calculate_start_end_timesteps(model_wrap, negative)
|
||||||
|
calculate_start_end_timesteps(model_wrap, positive)
|
||||||
|
|
||||||
|
#make sure each cond area has an opposite one with the same area
|
||||||
|
for c in positive:
|
||||||
|
create_cond_with_same_area_if_none(negative, c)
|
||||||
|
for c in negative:
|
||||||
|
create_cond_with_same_area_if_none(positive, c)
|
||||||
|
|
||||||
|
pre_run_control(model_wrap, negative + positive)
|
||||||
|
|
||||||
|
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
|
if model.is_adm():
|
||||||
|
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||||
|
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||||
|
|
||||||
|
if latent_image is not None:
|
||||||
|
latent_image = model.process_latent_in(latent_image)
|
||||||
|
|
||||||
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||||
|
|
||||||
|
cond_concat = None
|
||||||
|
if hasattr(model, 'concat_keys'): #inpaint
|
||||||
|
cond_concat = []
|
||||||
|
for ck in model.concat_keys:
|
||||||
|
if denoise_mask is not None:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(denoise_mask[:,:1])
|
||||||
|
elif ck == "masked_image":
|
||||||
|
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||||
|
else:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||||
|
elif ck == "masked_image":
|
||||||
|
cond_concat.append(blank_inpaint_image_like(noise))
|
||||||
|
extra_args["cond_concat"] = cond_concat
|
||||||
|
|
||||||
|
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||||
|
return model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
SAMPLERS = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
|
||||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"]
|
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -628,117 +769,13 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
return torch.zeros_like(noise)
|
return torch.zeros_like(noise)
|
||||||
|
|
||||||
positive = positive[:]
|
|
||||||
negative = negative[:]
|
|
||||||
|
|
||||||
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
|
|
||||||
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
|
|
||||||
|
|
||||||
calculate_start_end_timesteps(self.model_wrap, negative)
|
|
||||||
calculate_start_end_timesteps(self.model_wrap, positive)
|
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
|
||||||
for c in positive:
|
|
||||||
create_cond_with_same_area_if_none(negative, c)
|
|
||||||
for c in negative:
|
|
||||||
create_cond_with_same_area_if_none(positive, c)
|
|
||||||
|
|
||||||
pre_run_control(self.model_wrap, negative + positive)
|
|
||||||
|
|
||||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
|
||||||
|
|
||||||
if self.model.is_adm():
|
|
||||||
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
|
||||||
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
|
||||||
|
|
||||||
if latent_image is not None:
|
|
||||||
latent_image = self.model.process_latent_in(latent_image)
|
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
|
|
||||||
|
|
||||||
cond_concat = None
|
|
||||||
if hasattr(self.model, 'concat_keys'): #inpaint
|
|
||||||
cond_concat = []
|
|
||||||
for ck in self.model.concat_keys:
|
|
||||||
if denoise_mask is not None:
|
|
||||||
if ck == "mask":
|
|
||||||
cond_concat.append(denoise_mask[:,:1])
|
|
||||||
elif ck == "masked_image":
|
|
||||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
|
||||||
else:
|
|
||||||
if ck == "mask":
|
|
||||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
|
||||||
elif ck == "masked_image":
|
|
||||||
cond_concat.append(blank_inpaint_image_like(noise))
|
|
||||||
extra_args["cond_concat"] = cond_concat
|
|
||||||
|
|
||||||
if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0):
|
|
||||||
max_denoise = False
|
|
||||||
else:
|
|
||||||
max_denoise = True
|
|
||||||
|
|
||||||
|
|
||||||
if self.sampler == "uni_pc":
|
if self.sampler == "uni_pc":
|
||||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
sampler = UNIPC
|
||||||
elif self.sampler == "uni_pc_bh2":
|
elif self.sampler == "uni_pc_bh2":
|
||||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
sampler = UNIPCBH2
|
||||||
elif self.sampler == "ddim":
|
elif self.sampler == "ddim":
|
||||||
timesteps = []
|
sampler = DDIM
|
||||||
for s in range(sigmas.shape[0]):
|
|
||||||
timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
|
||||||
noise_mask = None
|
|
||||||
if denoise_mask is not None:
|
|
||||||
noise_mask = 1.0 - denoise_mask
|
|
||||||
|
|
||||||
ddim_callback = None
|
|
||||||
if callback is not None:
|
|
||||||
total_steps = len(timesteps) - 1
|
|
||||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
|
||||||
|
|
||||||
sampler = DDIMSampler(self.model, device=self.device)
|
|
||||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
|
||||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
|
||||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
|
||||||
conditioning=positive,
|
|
||||||
batch_size=noise.shape[0],
|
|
||||||
shape=noise.shape[1:],
|
|
||||||
verbose=False,
|
|
||||||
unconditional_guidance_scale=cfg,
|
|
||||||
unconditional_conditioning=negative,
|
|
||||||
eta=0.0,
|
|
||||||
x_T=z_enc,
|
|
||||||
x0=latent_image,
|
|
||||||
img_callback=ddim_callback,
|
|
||||||
denoise_function=self.model_wrap.predict_eps_discrete_timestep,
|
|
||||||
extra_args=extra_args,
|
|
||||||
mask=noise_mask,
|
|
||||||
to_zero=sigmas[-1]==0,
|
|
||||||
end_step=sigmas.shape[0] - 1,
|
|
||||||
disable_pbar=disable_pbar)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
extra_args["denoise_mask"] = denoise_mask
|
sampler = ksampler(self.sampler)
|
||||||
self.model_k.latent_image = latent_image
|
|
||||||
self.model_k.noise = noise
|
|
||||||
|
|
||||||
if max_denoise:
|
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
|
||||||
else:
|
|
||||||
noise = noise * sigmas[0]
|
|
||||||
|
|
||||||
k_callback = None
|
|
||||||
total_steps = len(sigmas) - 1
|
|
||||||
if callback is not None:
|
|
||||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
|
||||||
|
|
||||||
if latent_image is not None:
|
|
||||||
noise += latent_image
|
|
||||||
if self.sampler == "dpm_fast":
|
|
||||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
elif self.sampler == "dpm_adaptive":
|
|
||||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
else:
|
|
||||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
|
||||||
|
|
||||||
return self.model.process_latent_out(samples.to(torch.float32))
|
|
||||||
|
Loading…
Reference in New Issue
Block a user