Add support for TAESD decoder for SDXL.

This commit is contained in:
comfyanonymous 2023-06-25 02:38:14 -04:00
parent 20f579d91d
commit cef6aa62b2
4 changed files with 24 additions and 15 deletions

View File

@ -193,7 +193,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c
Use ```--preview-method auto``` to enable previews. Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
## Support and dev channel ## Support and dev channel

View File

@ -9,8 +9,23 @@ class LatentFormat:
class SD15(LatentFormat): class SD15(LatentFormat):
def __init__(self, scale_factor=0.18215): def __init__(self, scale_factor=0.18215):
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.latent_rgb_factors = [
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
]
self.taesd_decoder_name = "taesd_decoder.pth"
class SDXL(LatentFormat): class SDXL(LatentFormat):
def __init__(self): def __init__(self):
self.scale_factor = 0.13025 self.scale_factor = 0.13025
self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
]
self.taesd_decoder_name = "taesdxl_decoder.pth"

View File

@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self): def __init__(self, latent_rgb_factors):
self.latent_rgb_factors = torch.tensor([ self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
# R G B
[0.298, 0.207, 0.208], # L1
[0.187, 0.286, 0.173], # L2
[-0.158, 0.189, 0.264], # L3
[-0.184, -0.271, -0.473], # L4
], device="cpu")
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
return Image.fromarray(latents_ubyte.numpy()) return Image.fromarray(latents_ubyte.numpy())
def get_previewer(device): def get_previewer(device, latent_format):
previewer = None previewer = None
method = args.preview_method method = args.preview_method
if method != LatentPreviewMethod.NoPreviews: if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods # TODO previewer methods
taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth") taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
if method == LatentPreviewMethod.Auto: if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB method = LatentPreviewMethod.Latent2RGB
@ -86,10 +80,10 @@ def get_previewer(device):
taesd = TAESD(None, taesd_decoder_path).to(device) taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd) previewer = TAESDPreviewerImpl(taesd)
else: else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth") print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None: if previewer is None:
previewer = Latent2RGBPreviewer() previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
return previewer return previewer

View File

@ -954,7 +954,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if preview_format not in ["JPEG", "PNG"]: if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG" preview_format = "JPEG"
previewer = latent_preview.get_previewer(device) previewer = latent_preview.get_previewer(device, model.model.latent_format)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps): def callback(step, x0, x, total_steps):