mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Supports TAESD models in safetensors format
This commit is contained in:
parent
ae3e4e9ad8
commit
9eb621c95a
@ -20,7 +20,7 @@ class SD15(LatentFormat):
|
|||||||
[-0.2829, 0.1762, 0.2721],
|
[-0.2829, 0.1762, 0.2721],
|
||||||
[-0.2120, -0.2616, -0.7177]
|
[-0.2120, -0.2616, -0.7177]
|
||||||
]
|
]
|
||||||
self.taesd_decoder_name = "taesd_decoder.pth"
|
self.taesd_decoder_name = "taesd_decoder"
|
||||||
|
|
||||||
class SDXL(LatentFormat):
|
class SDXL(LatentFormat):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -32,4 +32,4 @@ class SDXL(LatentFormat):
|
|||||||
[ 0.0568, 0.1687, -0.0755],
|
[ 0.0568, 0.1687, -0.0755],
|
||||||
[-0.3112, -0.2359, -0.2076]
|
[-0.3112, -0.2359, -0.2076]
|
||||||
]
|
]
|
||||||
self.taesd_decoder_name = "taesdxl_decoder.pth"
|
self.taesd_decoder_name = "taesdxl_decoder"
|
||||||
|
@ -50,9 +50,17 @@ class TAESD(nn.Module):
|
|||||||
self.encoder = Encoder()
|
self.encoder = Encoder()
|
||||||
self.decoder = Decoder()
|
self.decoder = Decoder()
|
||||||
if encoder_path is not None:
|
if encoder_path is not None:
|
||||||
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
|
if encoder_path.lower().endswith(".safetensors"):
|
||||||
|
import safetensors.torch
|
||||||
|
self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu"))
|
||||||
|
else:
|
||||||
|
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
|
||||||
if decoder_path is not None:
|
if decoder_path is not None:
|
||||||
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
|
if decoder_path.lower().endswith(".safetensors"):
|
||||||
|
import safetensors.torch
|
||||||
|
self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu"))
|
||||||
|
else:
|
||||||
|
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def scale_latents(x):
|
def scale_latents(x):
|
||||||
|
@ -56,7 +56,12 @@ def get_previewer(device, latent_format):
|
|||||||
# TODO previewer methods
|
# TODO previewer methods
|
||||||
taesd_decoder_path = None
|
taesd_decoder_path = None
|
||||||
if latent_format.taesd_decoder_name is not None:
|
if latent_format.taesd_decoder_name is not None:
|
||||||
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
|
taesd_decoder_path = next(
|
||||||
|
(fn for fn in folder_paths.get_filename_list("vae_approx")
|
||||||
|
if fn.startswith(latent_format.taesd_decoder_name)),
|
||||||
|
""
|
||||||
|
)
|
||||||
|
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
|
||||||
|
|
||||||
if method == LatentPreviewMethod.Auto:
|
if method == LatentPreviewMethod.Auto:
|
||||||
method = LatentPreviewMethod.Latent2RGB
|
method = LatentPreviewMethod.Latent2RGB
|
||||||
|
Loading…
Reference in New Issue
Block a user