mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Cleanup.
This commit is contained in:
parent
851bb87ca9
commit
5e885bd9c8
@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
def conv(n_in, n_out, **kwargs):
|
def conv(n_in, n_out, **kwargs):
|
||||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||||
|
|
||||||
@ -50,17 +52,9 @@ class TAESD(nn.Module):
|
|||||||
self.encoder = Encoder()
|
self.encoder = Encoder()
|
||||||
self.decoder = Decoder()
|
self.decoder = Decoder()
|
||||||
if encoder_path is not None:
|
if encoder_path is not None:
|
||||||
if encoder_path.lower().endswith(".safetensors"):
|
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||||
import safetensors.torch
|
|
||||||
self.encoder.load_state_dict(safetensors.torch.load_file(encoder_path, device="cpu"))
|
|
||||||
else:
|
|
||||||
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
|
|
||||||
if decoder_path is not None:
|
if decoder_path is not None:
|
||||||
if decoder_path.lower().endswith(".safetensors"):
|
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||||
import safetensors.torch
|
|
||||||
self.decoder.load_state_dict(safetensors.torch.load_file(decoder_path, device="cpu"))
|
|
||||||
else:
|
|
||||||
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def scale_latents(x):
|
def scale_latents(x):
|
||||||
|
Loading…
Reference in New Issue
Block a user