Make new LATENT_PREVIEWER type for declaring KSampler preview methods

This commit is contained in:
space-nuko 2023-05-30 23:43:22 -05:00
parent a9fa2d3727
commit f326a0a468

View File

@ -34,6 +34,11 @@ import importlib
import folder_paths
class LatentPreviewer:
def decode_latent_to_preview(self, device, x0):
pass
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
@ -282,6 +287,27 @@ class TAESDEncode:
samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device)
return ({"samples": samples}, )
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
self.taesd = taesd
def decode_latent_to_preview(self, device, x0):
x_sample = self.taesd.decoder(x0.to(device))[0].detach()
x_sample = self.taesd.unscale_latents(x_sample) # returns value in [-2, 2]
x_sample = x_sample * 0.5
return x_sample
class TAESDPreviewer:
@classmethod
def INPUT_TYPES(s):
return {"required": { "taesd": ("TAESD", ), }}
RETURN_TYPES = ("LATENT_PREVIEWER",)
FUNCTION = "make_previewer"
CATEGORY = "latent/previewer"
def make_previewer(self, taesd):
return (TAESDPreviewerImpl(taesd), )
class SaveLatent:
def __init__(self):
@ -986,10 +1012,8 @@ class SetLatentNoiseMask:
return (s,)
def decode_latent_to_preview_image(taesd, device, preview_format, x0):
x_sample = taesd.decoder(x0.to(device))[0].detach()
x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2]
x_sample = x_sample * 0.5
def decode_latent_to_preview_image(previewer, device, preview_format, x0):
x_sample = previewer.decode_latent_to_preview(device, x0)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@ -1015,7 +1039,7 @@ def decode_latent_to_preview_image(taesd, device, preview_format, x0):
return preview_bytes
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None):
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, previewer=None):
device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
@ -1036,8 +1060,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
preview_bytes = None
if taesd:
preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0)
if previewer:
preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
@ -1063,7 +1087,7 @@ class KSampler:
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
},
"optional": {
"taesd": ("TAESD",)
"previewer": ("LATENT_PREVIEWER",)
}}
RETURN_TYPES = ("LATENT",)
@ -1071,8 +1095,8 @@ class KSampler:
CATEGORY = "sampling"
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd)
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, previewer=None):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, previewer=previewer)
class KSamplerAdvanced:
@classmethod
@ -1093,7 +1117,7 @@ class KSamplerAdvanced:
"return_with_leftover_noise": (["disable", "enable"], ),
},
"optional": {
"taesd": ("TAESD",)
"previewer": ("LATENT_PREVIEWER",)
}}
RETURN_TYPES = ("LATENT",)
@ -1101,14 +1125,14 @@ class KSamplerAdvanced:
CATEGORY = "sampling"
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None):
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, previewer=None):
force_full_denoise = True
if return_with_leftover_noise == "enable":
force_full_denoise = False
disable_noise = False
if add_noise == "disable":
disable_noise = True
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd)
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, previewer=previewer)
class SaveImage:
def __init__(self):
@ -1369,6 +1393,7 @@ NODE_CLASS_MAPPINGS = {
"VAELoader": VAELoader,
"TAESDDecode": TAESDDecode,
"TAESDEncode": TAESDEncode,
"TAESDPreviewer": TAESDPreviewer,
"TAESDLoader": TAESDLoader,
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
@ -1425,6 +1450,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"TAESDLoader": "Load TAESD",
"TAESDPreviewer": "TAESD Previewer",
"LoraLoader": "Load LoRA",
"CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model",