diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc4709f7..fdb2a34d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,4 +1,35 @@ import argparse +import enum + + +class EnumAction(argparse.Action): + """ + Argparse action for handling Enums + """ + def __init__(self, **kwargs): + # Pop off the type value + enum_type = kwargs.pop("type", None) + + # Ensure an Enum subclass is provided + if enum_type is None: + raise ValueError("type must be assigned an Enum when using EnumAction") + if not issubclass(enum_type, enum.Enum): + raise TypeError("type must be an Enum when using EnumAction") + + # Generate choices from the Enum + choices = tuple(e.value for e in enum_type) + kwargs.setdefault("choices", choices) + kwargs.setdefault("metavar", f"[{','.join(list(choices))}]") + + super(EnumAction, self).__init__(**kwargs) + + self._enum = enum_type + + def __call__(self, parser, namespace, values, option_string=None): + # Convert value back into an Enum + value = self._enum(values) + setattr(namespace, self.dest, value) + parser = argparse.ArgumentParser() @@ -13,6 +44,11 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") +class PreviewType(enum.Enum): + TAESD = "taesd" +parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.") +parser.add_argument("--default-preview-method", type=str, default=PreviewType.TAESD, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") diff --git a/main.py b/main.py index 9f1c8928..15f75f89 100644 --- a/main.py +++ b/main.py @@ -45,7 +45,6 @@ def hijack_progress(server): server.send_sync("progress", { "value": value, "max": total}, server.client_id) if preview_image_bytes is not None: server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) - pass comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): diff --git a/nodes.py b/nodes.py index ec9d9984..76074782 100644 --- a/nodes.py +++ b/nodes.py @@ -24,6 +24,7 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils +from comfy.cli_args import args from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -180,21 +181,6 @@ class VAEDecodeTiled: def decode(self, vae, samples): return (vae.decode_tiled(samples["samples"]), ) -class TAESDDecode: - @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "taesd": ("TAESD", )}} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "decode" - - CATEGORY = "latent" - - def decode(self, taesd, samples): - device = comfy.model_management.get_torch_device() - # [B, C, H, W] -> [B, H, W, C] - pixels = taesd.decoder(samples["samples"].to(device)).permute(0, 2, 3, 1).detach().clamp(0, 1) - return (pixels, ) - class VAEEncode: @classmethod def INPUT_TYPES(s): @@ -272,21 +258,6 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) -class TAESDEncode: - @classmethod - def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "taesd": ("TAESD", )}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "latent" - - def encode(self, taesd, pixels): - device = comfy.model_management.get_torch_device() - # [B, H, W, C] -> [B, C, H, W] - samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device) - return ({"samples": samples}, ) - class TAESDPreviewerImpl(LatentPreviewer): def __init__(self, taesd): self.taesd = taesd @@ -297,18 +268,6 @@ class TAESDPreviewerImpl(LatentPreviewer): x_sample = x_sample * 0.5 return x_sample -class TAESDPreviewer: - @classmethod - def INPUT_TYPES(s): - return {"required": { "taesd": ("TAESD", ), }} - RETURN_TYPES = ("LATENT_PREVIEWER",) - FUNCTION = "make_previewer" - - CATEGORY = "latent/previewer" - - def make_previewer(self, taesd): - return (TAESDPreviewerImpl(taesd), ) - class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -524,26 +483,6 @@ class VAELoader: vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) -class TAESDLoader: - @classmethod - def INPUT_TYPES(s): - model_list = folder_paths.get_filename_list("taesd") - return {"required": { - "encoder_name": (model_list, { "default": "taesd_encoder.pth" }), - "decoder_name": (model_list, { "default": "taesd_decoder.pth" }) - }} - RETURN_TYPES = ("TAESD",) - FUNCTION = "load_taesd" - - CATEGORY = "loaders" - - def load_taesd(self, encoder_name, decoder_name): - device = comfy.model_management.get_torch_device() - encoder_path = folder_paths.get_full_path("taesd", encoder_name) - decoder_path = folder_paths.get_full_path("taesd", decoder_name) - taesd = TAESD(encoder_path, decoder_path).to(device) - return (taesd,) - class ControlNetLoader: @classmethod def INPUT_TYPES(s): @@ -1039,7 +978,7 @@ def decode_latent_to_preview_image(previewer, device, preview_format, x0): return preview_bytes -def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, previewer=None): +def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): device = comfy.model_management.get_torch_device() latent_image = latent["samples"] @@ -1057,6 +996,17 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" + previewer = None + if not args.disable_previews: + # TODO previewer methods + encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") + decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") + if encoder_path and decoder_path: + taesd = TAESD(encoder_path, decoder_path).to(device) + previewer = TAESDPreviewerImpl(taesd) + else: + print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth") + pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): preview_bytes = None @@ -1085,18 +1035,16 @@ class KSampler: "negative": ("CONDITIONING", ), "latent_image": ("LATENT", ), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }, - "optional": { - "previewer": ("LATENT_PREVIEWER",) - }} + } + } RETURN_TYPES = ("LATENT",) FUNCTION = "sample" CATEGORY = "sampling" - def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, previewer=None): - return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, previewer=previewer) + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): + return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) class KSamplerAdvanced: @classmethod @@ -1115,24 +1063,22 @@ class KSamplerAdvanced: "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), "return_with_leftover_noise": (["disable", "enable"], ), - }, - "optional": { - "previewer": ("LATENT_PREVIEWER",) - }} + } + } RETURN_TYPES = ("LATENT",) FUNCTION = "sample" CATEGORY = "sampling" - def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, previewer=None): + def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): force_full_denoise = True if return_with_leftover_noise == "enable": force_full_denoise = False disable_noise = False if add_noise == "disable": disable_noise = True - return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, previewer=previewer) + return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) class SaveImage: def __init__(self): @@ -1391,10 +1337,6 @@ NODE_CLASS_MAPPINGS = { "VAEEncode": VAEEncode, "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, - "TAESDDecode": TAESDDecode, - "TAESDEncode": TAESDEncode, - "TAESDPreviewer": TAESDPreviewer, - "TAESDLoader": TAESDLoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentUpscaleBy": LatentUpscaleBy, @@ -1449,8 +1391,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoader": "Load Checkpoint (With Config)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", - "TAESDLoader": "Load TAESD", - "TAESDPreviewer": "TAESD Previewer", "LoraLoader": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", @@ -1473,8 +1413,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SetLatentNoiseMask": "Set Latent Noise Mask", "VAEDecode": "VAE Decode", "VAEEncode": "VAE Encode", - "TAESDDecode": "TAESD Decode", - "TAESDEncode": "TAESD Encode", "LatentRotate": "Rotate Latent", "LatentFlip": "Flip Latent", "LatentCrop": "Crop Latent", diff --git a/web/scripts/app.js b/web/scripts/app.js index 495d43e1..9df94c9e 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -382,7 +382,7 @@ export class ComfyApp { this.images = output.images; imagesChanged = true; imgURLs = imgURLs.concat(output.images.map(params => { - return "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam(); + return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam(); })) } }