From cef6aa62b2745ac84f0c0d875a614cbf45ac5661 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 25 Jun 2023 02:38:14 -0400 Subject: [PATCH] Add support for TAESD decoder for SDXL. --- README.md | 2 +- comfy/latent_formats.py | 17 ++++++++++++++++- latent_preview.py | 18 ++++++------------ nodes.py | 2 +- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index ccbe234f..56ee873e 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c Use ```--preview-method auto``` to enable previews. -The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. +The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. ## Support and dev channel diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 3e193828..07937f73 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -9,8 +9,23 @@ class LatentFormat: class SD15(LatentFormat): def __init__(self, scale_factor=0.18215): self.scale_factor = scale_factor + self.latent_rgb_factors = [ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + self.taesd_decoder_name = "taesd_decoder.pth" class SDXL(LatentFormat): def __init__(self): self.scale_factor = 0.13025 - + self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ] + self.taesd_decoder_name = "taesdxl_decoder.pth" diff --git a/latent_preview.py b/latent_preview.py index ef6c201b..1d143339 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self): - self.latent_rgb_factors = torch.tensor([ - # R G B - [0.298, 0.207, 0.208], # L1 - [0.187, 0.286, 0.173], # L2 - [-0.158, 0.189, 0.264], # L3 - [-0.184, -0.271, -0.473], # L4 - ], device="cpu") + def __init__(self, latent_rgb_factors): + self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu") def decode_latent_to_preview(self, x0): latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors @@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer): return Image.fromarray(latents_ubyte.numpy()) -def get_previewer(device): +def get_previewer(device, latent_format): previewer = None method = args.preview_method if method != LatentPreviewMethod.NoPreviews: # TODO previewer methods - taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth") + taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB @@ -86,10 +80,10 @@ def get_previewer(device): taesd = TAESD(None, taesd_decoder_path).to(device) previewer = TAESDPreviewerImpl(taesd) else: - print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth") + print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: - previewer = Latent2RGBPreviewer() + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) return previewer diff --git a/nodes.py b/nodes.py index c565501a..456805c1 100644 --- a/nodes.py +++ b/nodes.py @@ -954,7 +954,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" - previewer = latent_preview.get_previewer(device) + previewer = latent_preview.get_previewer(device, model.model.latent_format) pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps):