refactor/split various bits of code for sampling

This commit is contained in:
BlenderNeko 2023-04-23 20:02:08 +02:00
parent 989acd769a
commit 2a09e2aa27
3 changed files with 108 additions and 78 deletions

62
comfy/sample.py Normal file
View File

@ -0,0 +1,62 @@
import torch
import comfy.model_management
def prepare_noise(latent, seed, disable_noise):
latent_image = latent["samples"]
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_index = 0
if "batch_index" in latent:
batch_index = latent["batch_index"]
generator = torch.manual_seed(seed)
for i in range(batch_index):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
return noise
def create_mask(latent, noise):
noise_mask = None
device = comfy.model_management.get_torch_device()
if "noise_mask" in latent:
noise_mask = latent['noise_mask']
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0])
noise_mask = noise_mask.to(device)
return noise_mask
def broadcast_cond(cond, noise):
device = comfy.model_management.get_torch_device()
copy = []
for p in cond:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
copy += [[t] + p[1:]]
return copy
def load_c_nets(positive, negative):
def get_models(cond):
models = []
for c in cond:
if 'control' in c[1]:
models += [c[1]['control']]
if 'gligen' in c[1]:
models += [c[1]['gligen'][1]]
return models
return get_models(positive) + get_models(negative)
def load_additional_models(positive, negative):
models = load_c_nets(positive, negative)
comfy.model_management.load_controlnet_gpu(models)
return models
def cleanup_additional_models(models):
for m in models:
m.cleanup()

View File

@ -392,6 +392,38 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
return conds return conds
def calculate_sigmas(model, steps, scheduler, sampler):
"""
Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique
"""
if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)):
model = CFGNoisePredictor(model)
if model.inner_model.parameterization == "v":
model = CompVisVDenoiser(model, quantize=True)
else:
model = k_diffusion_external.CompVisDenoiser(model, quantize=True)
sigmas = None
discard_penultimate_sigma = False
if sampler in ['dpm_2', 'dpm_2_ancestral']:
steps += 1
discard_penultimate_sigma = True
if scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max))
elif scheduler == "normal":
sigmas = model.get_sigmas(steps)
elif scheduler == "simple":
sigmas = simple_scheduler(model, steps)
elif scheduler == "ddim_uniform":
sigmas = ddim_scheduler(model, steps)
else:
print("error invalid scheduler", scheduler)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
class KSampler: class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
@ -421,41 +453,19 @@ class KSampler:
self.denoise = denoise self.denoise = denoise
self.model_options = model_options self.model_options = model_options
def _calculate_sigmas(self, steps):
sigmas = None
discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
steps += 1
discard_penultimate_sigma = True
if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
else:
print("error invalid scheduler", self.scheduler)
if discard_penultimate_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
def set_steps(self, steps, denoise=None): def set_steps(self, steps, denoise=None):
self.steps = steps self.steps = steps
if denoise is None or denoise > 0.9999: if denoise is None or denoise > 0.9999:
self.sigmas = self._calculate_sigmas(steps) self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device)
else: else:
new_steps = int(steps/denoise) new_steps = int(steps/denoise)
sigmas = self._calculate_sigmas(new_steps) sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device)
self.sigmas = sigmas[-(steps + 1):] self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None):
sigmas = self.sigmas if sigmas is None:
sigmas = self.sigmas
sigma_min = self.sigma_min sigma_min = self.sigma_min
if last_step is not None and last_step < (len(sigmas) - 1): if last_step is not None and last_step < (len(sigmas) - 1):

View File

@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co
import comfy.diffusers_convert import comfy.diffusers_convert
import comfy.samplers import comfy.samplers
import comfy.sample
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
@ -739,31 +740,12 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask s["noise_mask"] = mask
return (s,) return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
noise_mask = None
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
if disable_noise: noise = comfy.sample.prepare_noise(latent, seed, disable_noise)
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") noise_mask = comfy.sample.create_mask(latent, noise)
else:
batch_index = 0
if "batch_index" in latent:
batch_index = latent["batch_index"]
generator = torch.manual_seed(seed)
for i in range(batch_index):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if "noise_mask" in latent:
noise_mask = latent['noise_mask']
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0])
noise_mask = noise_mask.to(device)
real_model = None real_model = None
comfy.model_management.load_model_gpu(model) comfy.model_management.load_model_gpu(model)
@ -772,34 +754,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
noise = noise.to(device) noise = noise.to(device)
latent_image = latent_image.to(device) latent_image = latent_image.to(device)
positive_copy = [] positive_copy = comfy.sample.broadcast_cond(positive, noise)
negative_copy = [] negative_copy = comfy.sample.broadcast_cond(negative, noise)
control_nets = [] models = comfy.sample.load_additional_models(positive, negative)
def get_models(cond):
models = []
for c in cond:
if 'control' in c[1]:
models += [c[1]['control']]
if 'gligen' in c[1]:
models += [c[1]['gligen'][1]]
return models
for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
negative_copy += [[t] + n[1:]]
models = get_models(positive) + get_models(negative)
comfy.model_management.load_controlnet_gpu(models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS: if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
@ -809,8 +767,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu() samples = samples.cpu()
for m in models:
m.cleanup() comfy.sample.cleanup_additional_models(models)
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples