mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add support for TAESD decoder for SDXL.
This commit is contained in:
parent
20f579d91d
commit
cef6aa62b2
@ -193,7 +193,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c
|
|||||||
|
|
||||||
Use ```--preview-method auto``` to enable previews.
|
Use ```--preview-method auto``` to enable previews.
|
||||||
|
|
||||||
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
The default installation includes a fast, low-resolution latent preview method. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
|
||||||
|
|
||||||
## Support and dev channel
|
## Support and dev channel
|
||||||
|
|
||||||
|
@ -9,8 +9,23 @@ class LatentFormat:
|
|||||||
class SD15(LatentFormat):
|
class SD15(LatentFormat):
|
||||||
def __init__(self, scale_factor=0.18215):
|
def __init__(self, scale_factor=0.18215):
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
self.latent_rgb_factors = [
|
||||||
|
# R G B
|
||||||
|
[0.298, 0.207, 0.208], # L1
|
||||||
|
[0.187, 0.286, 0.173], # L2
|
||||||
|
[-0.158, 0.189, 0.264], # L3
|
||||||
|
[-0.184, -0.271, -0.473], # L4
|
||||||
|
]
|
||||||
|
self.taesd_decoder_name = "taesd_decoder.pth"
|
||||||
|
|
||||||
class SDXL(LatentFormat):
|
class SDXL(LatentFormat):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 0.13025
|
self.scale_factor = 0.13025
|
||||||
|
self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
|
||||||
|
# R G B
|
||||||
|
[0.298, 0.207, 0.208], # L1
|
||||||
|
[0.187, 0.286, 0.173], # L2
|
||||||
|
[-0.158, 0.189, 0.264], # L3
|
||||||
|
[-0.184, -0.271, -0.473], # L4
|
||||||
|
]
|
||||||
|
self.taesd_decoder_name = "taesdxl_decoder.pth"
|
||||||
|
@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
|||||||
|
|
||||||
|
|
||||||
class Latent2RGBPreviewer(LatentPreviewer):
|
class Latent2RGBPreviewer(LatentPreviewer):
|
||||||
def __init__(self):
|
def __init__(self, latent_rgb_factors):
|
||||||
self.latent_rgb_factors = torch.tensor([
|
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")
|
||||||
# R G B
|
|
||||||
[0.298, 0.207, 0.208], # L1
|
|
||||||
[0.187, 0.286, 0.173], # L2
|
|
||||||
[-0.158, 0.189, 0.264], # L3
|
|
||||||
[-0.184, -0.271, -0.473], # L4
|
|
||||||
], device="cpu")
|
|
||||||
|
|
||||||
def decode_latent_to_preview(self, x0):
|
def decode_latent_to_preview(self, x0):
|
||||||
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
|
latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors
|
||||||
@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
|
|||||||
return Image.fromarray(latents_ubyte.numpy())
|
return Image.fromarray(latents_ubyte.numpy())
|
||||||
|
|
||||||
|
|
||||||
def get_previewer(device):
|
def get_previewer(device, latent_format):
|
||||||
previewer = None
|
previewer = None
|
||||||
method = args.preview_method
|
method = args.preview_method
|
||||||
if method != LatentPreviewMethod.NoPreviews:
|
if method != LatentPreviewMethod.NoPreviews:
|
||||||
# TODO previewer methods
|
# TODO previewer methods
|
||||||
taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
|
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
|
||||||
|
|
||||||
if method == LatentPreviewMethod.Auto:
|
if method == LatentPreviewMethod.Auto:
|
||||||
method = LatentPreviewMethod.Latent2RGB
|
method = LatentPreviewMethod.Latent2RGB
|
||||||
@ -86,10 +80,10 @@ def get_previewer(device):
|
|||||||
taesd = TAESD(None, taesd_decoder_path).to(device)
|
taesd = TAESD(None, taesd_decoder_path).to(device)
|
||||||
previewer = TAESDPreviewerImpl(taesd)
|
previewer = TAESDPreviewerImpl(taesd)
|
||||||
else:
|
else:
|
||||||
print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
|
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
|
||||||
|
|
||||||
if previewer is None:
|
if previewer is None:
|
||||||
previewer = Latent2RGBPreviewer()
|
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
|
||||||
return previewer
|
return previewer
|
||||||
|
|
||||||
|
|
||||||
|
2
nodes.py
2
nodes.py
@ -954,7 +954,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||||||
if preview_format not in ["JPEG", "PNG"]:
|
if preview_format not in ["JPEG", "PNG"]:
|
||||||
preview_format = "JPEG"
|
preview_format = "JPEG"
|
||||||
|
|
||||||
previewer = latent_preview.get_previewer(device)
|
previewer = latent_preview.get_previewer(device, model.model.latent_format)
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
def callback(step, x0, x, total_steps):
|
def callback(step, x0, x, total_steps):
|
||||||
|
Loading…
Reference in New Issue
Block a user