mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Sampling code refactor.
This commit is contained in:
parent
aeba1cc2a0
commit
446caf711c
@ -59,7 +59,7 @@ class DDIMSampler(object):
|
||||
@torch.no_grad()
|
||||
def sample_custom(self,
|
||||
ddim_timesteps,
|
||||
conditioning,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
|
@ -544,11 +544,152 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
||||
return conds
|
||||
|
||||
|
||||
class Sampler:
|
||||
def sample(self):
|
||||
pass
|
||||
|
||||
def max_denoise(self, model_wrap, sigmas):
|
||||
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]))
|
||||
|
||||
class DDIM(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
|
||||
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
|
||||
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=model_wrap.predict_eps_discrete_timestep,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
return samples
|
||||
|
||||
class UNIPC(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
|
||||
class UNIPCBH2(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
|
||||
|
||||
def ksampler(sampler_name):
|
||||
class KSAMPLER(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k.latent_image = latent_image
|
||||
model_k.noise = noise
|
||||
|
||||
if self.max_denoise(model_wrap, sigmas):
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
sigma_min = sigmas[-1]
|
||||
if sigma_min == 0:
|
||||
sigma_min = sigmas[-2]
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if sampler_name == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif sampler_name == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
return samples
|
||||
return KSAMPLER
|
||||
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
negative = negative[:]
|
||||
|
||||
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device)
|
||||
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device)
|
||||
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
if model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
|
||||
else:
|
||||
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
|
||||
|
||||
calculate_start_end_timesteps(model_wrap, negative)
|
||||
calculate_start_end_timesteps(model_wrap, positive)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
create_cond_with_same_area_if_none(negative, c)
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
pre_run_control(model_wrap, negative + positive)
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if model.is_adm():
|
||||
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
|
||||
cond_concat = None
|
||||
if hasattr(model, 'concat_keys'): #inpaint
|
||||
cond_concat = []
|
||||
for ck in model.concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
extra_args["cond_concat"] = cond_concat
|
||||
|
||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||
SAMPLERS = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||
self.model = model
|
||||
@ -628,117 +769,13 @@ class KSampler:
|
||||
else:
|
||||
return torch.zeros_like(noise)
|
||||
|
||||
positive = positive[:]
|
||||
negative = negative[:]
|
||||
|
||||
resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
|
||||
resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
|
||||
|
||||
calculate_start_end_timesteps(self.model_wrap, negative)
|
||||
calculate_start_end_timesteps(self.model_wrap, positive)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
create_cond_with_same_area_if_none(negative, c)
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
pre_run_control(self.model_wrap, negative + positive)
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if self.model.is_adm():
|
||||
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
||||
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = self.model.process_latent_in(latent_image)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
|
||||
|
||||
cond_concat = None
|
||||
if hasattr(self.model, 'concat_keys'): #inpaint
|
||||
cond_concat = []
|
||||
for ck in self.model.concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
extra_args["cond_concat"] = cond_concat
|
||||
|
||||
if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0):
|
||||
max_denoise = False
|
||||
else:
|
||||
max_denoise = True
|
||||
|
||||
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
sampler = UNIPC
|
||||
elif self.sampler == "uni_pc_bh2":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
sampler = UNIPCBH2
|
||||
elif self.sampler == "ddim":
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
|
||||
sampler = DDIMSampler(self.model, device=self.device)
|
||||
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
conditioning=positive,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg,
|
||||
unconditional_conditioning=negative,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=self.model_wrap.predict_eps_discrete_timestep,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
|
||||
sampler = DDIM
|
||||
else:
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
self.model_k.noise = noise
|
||||
sampler = ksampler(self.sampler)
|
||||
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif self.sampler == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
|
||||
return self.model.process_latent_out(samples.to(torch.float32))
|
||||
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
|
Loading…
Reference in New Issue
Block a user