Add KSamplerAdvanced node.

This node exposes more sampling options and makes it possible for example
to sample the first few steps on the latent image, do some operations on it
 and then do the rest of the sampling steps. This can be achieved using the
start_at_step and end_at_step options.
This commit is contained in:
comfyanonymous 2023-01-31 03:09:38 -05:00
parent f8f165e2c3
commit 69df7eba94
2 changed files with 83 additions and 31 deletions

View File

@ -168,15 +168,24 @@ class KSampler:
self.sigmas = sigmas[-(steps + 1):] self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None): def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False):
sigmas = self.sigmas sigmas = self.sigmas
sigma_min = self.sigma_min sigma_min = self.sigma_min
if last_step is not None: if last_step is not None and last_step < (len(sigmas) - 1):
sigma_min = sigmas[last_step] sigma_min = sigmas[last_step]
sigmas = sigmas[:last_step + 1] sigmas = sigmas[:last_step + 1]
if force_full_denoise:
sigmas[-1] = 0
if start_step is not None: if start_step is not None:
sigmas = sigmas[start_step:] if start_step < (len(sigmas) - 1):
sigmas = sigmas[start_step:]
else:
if latent_image is not None:
return latent_image
else:
return torch.zeros_like(noise)
noise *= sigmas[0] noise *= sigmas[0]
if latent_image is not None: if latent_image is not None:

View File

@ -221,13 +221,50 @@ class LatentRotate:
s = torch.rot90(samples, k=rotate_by, dims=[3, 2]) s = torch.rot90(samples, k=rotate_by, dims=[3, 2])
return (s,) return (s,)
def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
model = model.to(device)
noise = noise.to(device)
latent_image = latent_image.to(device)
positive_copy = []
negative_copy = []
for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
negative_copy += [[t] + n[1:]]
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
else:
#other samplers
pass
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
samples = samples.cpu()
model = model.cpu()
return (samples, )
class KSampler: class KSampler:
def __init__(self, device="cuda"): def __init__(self, device="cuda"):
self.device = device self.device = device
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": return {"required":
{"model": ("MODEL",), {"model": ("MODEL",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
@ -246,38 +283,43 @@ class KSampler:
CATEGORY = "sampling" CATEGORY = "sampling"
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") return common_ksampler(self.device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
model = model.to(self.device)
noise = noise.to(self.device)
latent_image = latent_image.to(self.device)
positive_copy = [] class KSamplerAdvanced:
negative_copy = [] def __init__(self, device="cuda"):
self.device = device
for p in positive: @classmethod
t = p[0] def INPUT_TYPES(s):
if t.shape[0] < noise.shape[0]: return {"required":
t = torch.cat([t] * noise.shape[0]) {"model": ("MODEL",),
t = t.to(self.device) "add_noise": (["enable", "disable"], ),
positive_copy += [[t] + p[1:]] "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
for n in negative: "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
t = n[0] "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
if t.shape[0] < noise.shape[0]: "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
t = torch.cat([t] * noise.shape[0]) "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
t = t.to(self.device) "positive": ("CONDITIONING", ),
negative_copy += [[t] + n[1:]] "negative": ("CONDITIONING", ),
"latent_image": ("LATENT", ),
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
"return_with_leftover_noise": (["disable", "enable"], ),
}}
if sampler_name in comfy.samplers.KSampler.SAMPLERS: RETURN_TYPES = ("LATENT",)
sampler = comfy.samplers.KSampler(model, steps=steps, device=self.device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) FUNCTION = "sample"
else:
#other samplers
pass
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image) CATEGORY = "sampling"
samples = samples.cpu()
model = model.cpu()
return (samples, )
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
force_full_denoise = True
if return_with_leftover_noise == "enable":
force_full_denoise = False
disable_noise = False
if add_noise == "disable":
disable_noise = True
return common_ksampler(self.device, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
class SaveImage: class SaveImage:
def __init__(self): def __init__(self):
@ -365,6 +407,7 @@ NODE_CLASS_MAPPINGS = {
"LoadImage": LoadImage, "LoadImage": LoadImage,
"ConditioningCombine": ConditioningCombine, "ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea, "ConditioningSetArea": ConditioningSetArea,
"KSamplerAdvanced": KSamplerAdvanced,
"LatentRotate": LatentRotate, "LatentRotate": LatentRotate,
} }