Make new LATENT_PREVIEWER type for declaring KSampler preview methods

This commit is contained in:
space-nuko 2023-05-30 23:43:22 -05:00
parent a9fa2d3727
commit f326a0a468

View File

@ -34,6 +34,11 @@ import importlib
import folder_paths import folder_paths
class LatentPreviewer:
def decode_latent_to_preview(self, device, x0):
pass
def before_node_execution(): def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
@ -282,6 +287,27 @@ class TAESDEncode:
samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device) samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device)
return ({"samples": samples}, ) return ({"samples": samples}, )
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
self.taesd = taesd
def decode_latent_to_preview(self, device, x0):
x_sample = self.taesd.decoder(x0.to(device))[0].detach()
x_sample = self.taesd.unscale_latents(x_sample) # returns value in [-2, 2]
x_sample = x_sample * 0.5
return x_sample
class TAESDPreviewer:
@classmethod
def INPUT_TYPES(s):
return {"required": { "taesd": ("TAESD", ), }}
RETURN_TYPES = ("LATENT_PREVIEWER",)
FUNCTION = "make_previewer"
CATEGORY = "latent/previewer"
def make_previewer(self, taesd):
return (TAESDPreviewerImpl(taesd), )
class SaveLatent: class SaveLatent:
def __init__(self): def __init__(self):
@ -986,10 +1012,8 @@ class SetLatentNoiseMask:
return (s,) return (s,)
def decode_latent_to_preview_image(taesd, device, preview_format, x0): def decode_latent_to_preview_image(previewer, device, preview_format, x0):
x_sample = taesd.decoder(x0.to(device))[0].detach() x_sample = previewer.decode_latent_to_preview(device, x0)
x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2]
x_sample = x_sample * 0.5
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@ -1015,7 +1039,7 @@ def decode_latent_to_preview_image(taesd, device, preview_format, x0):
return preview_bytes return preview_bytes
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None): def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, previewer=None):
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
latent_image = latent["samples"] latent_image = latent["samples"]
@ -1036,8 +1060,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps): def callback(step, x0, x, total_steps):
preview_bytes = None preview_bytes = None
if taesd: if previewer:
preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0) preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes) pbar.update_absolute(step + 1, total_steps, preview_bytes)
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
@ -1063,7 +1087,7 @@ class KSampler:
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}, },
"optional": { "optional": {
"taesd": ("TAESD",) "previewer": ("LATENT_PREVIEWER",)
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
@ -1071,8 +1095,8 @@ class KSampler:
CATEGORY = "sampling" CATEGORY = "sampling"
def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None): def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, previewer=None):
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd) return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, previewer=previewer)
class KSamplerAdvanced: class KSamplerAdvanced:
@classmethod @classmethod
@ -1093,7 +1117,7 @@ class KSamplerAdvanced:
"return_with_leftover_noise": (["disable", "enable"], ), "return_with_leftover_noise": (["disable", "enable"], ),
}, },
"optional": { "optional": {
"taesd": ("TAESD",) "previewer": ("LATENT_PREVIEWER",)
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
@ -1101,14 +1125,14 @@ class KSamplerAdvanced:
CATEGORY = "sampling" CATEGORY = "sampling"
def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None): def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, previewer=None):
force_full_denoise = True force_full_denoise = True
if return_with_leftover_noise == "enable": if return_with_leftover_noise == "enable":
force_full_denoise = False force_full_denoise = False
disable_noise = False disable_noise = False
if add_noise == "disable": if add_noise == "disable":
disable_noise = True disable_noise = True
return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd) return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, previewer=previewer)
class SaveImage: class SaveImage:
def __init__(self): def __init__(self):
@ -1369,6 +1393,7 @@ NODE_CLASS_MAPPINGS = {
"VAELoader": VAELoader, "VAELoader": VAELoader,
"TAESDDecode": TAESDDecode, "TAESDDecode": TAESDDecode,
"TAESDEncode": TAESDEncode, "TAESDEncode": TAESDEncode,
"TAESDPreviewer": TAESDPreviewer,
"TAESDLoader": TAESDLoader, "TAESDLoader": TAESDLoader,
"EmptyLatentImage": EmptyLatentImage, "EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale, "LatentUpscale": LatentUpscale,
@ -1425,6 +1450,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoaderSimple": "Load Checkpoint", "CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE", "VAELoader": "Load VAE",
"TAESDLoader": "Load TAESD", "TAESDLoader": "Load TAESD",
"TAESDPreviewer": "TAESD Previewer",
"LoraLoader": "Load LoRA", "LoraLoader": "Load LoRA",
"CLIPLoader": "Load CLIP", "CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model", "ControlNetLoader": "Load ControlNet Model",