From 9eb621c95a7c867d118e90348057474cbc96c20c Mon Sep 17 00:00:00 2001 From: Yukimasa Funaoka Date: Tue, 10 Oct 2023 13:21:44 +0900 Subject: [PATCH] Supports TAESD models in safetensors format --- comfy/latent_formats.py | 4 ++-- comfy/taesd/taesd.py | 12 ++++++++++-- latent_preview.py | 7 ++++++- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index fadc0eec..c209087e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -20,7 +20,7 @@ class SD15(LatentFormat): [-0.2829, 0.1762, 0.2721], [-0.2120, -0.2616, -0.7177] ] - self.taesd_decoder_name = "taesd_decoder.pth" + self.taesd_decoder_name = "taesd_decoder" class SDXL(LatentFormat): def __init__(self): @@ -32,4 +32,4 @@ class SDXL(LatentFormat): [ 0.0568, 0.1687, -0.0755], [-0.3112, -0.2359, -0.2076] ] - self.taesd_decoder_name = "taesdxl_decoder.pth" + self.taesd_decoder_name = "taesdxl_decoder" diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 1549345a..92f74c11 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -50,9 +50,17 @@ class TAESD(nn.Module): self.encoder = Encoder() self.decoder = Decoder() if encoder_path is not None: - self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) + if encoder_path.lower().endswith(".safetensors"): + import safetensors.torch + self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu")) + else: + self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) if decoder_path is not None: - self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) + if decoder_path.lower().endswith(".safetensors"): + import safetensors.torch + self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu")) + else: + self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) @staticmethod def scale_latents(x): diff --git a/latent_preview.py b/latent_preview.py index 740e0860..e1553c85 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -56,7 +56,12 @@ def get_previewer(device, latent_format): # TODO previewer methods taesd_decoder_path = None if latent_format.taesd_decoder_name is not None: - taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) + taesd_decoder_path = next( + (fn for fn in folder_paths.get_filename_list("vae_approx") + if fn.startswith(latent_format.taesd_decoder_name)), + "" + ) + taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB