mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Move latent scale factor from VAE to model.
This commit is contained in:
parent
30a3861946
commit
8607c2d42d
16
comfy/latent_formats.py
Normal file
16
comfy/latent_formats.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
|
||||||
|
class LatentFormat:
|
||||||
|
def process_in(self, latent):
|
||||||
|
return latent * self.scale_factor
|
||||||
|
|
||||||
|
def process_out(self, latent):
|
||||||
|
return latent / self.scale_factor
|
||||||
|
|
||||||
|
class SD15(LatentFormat):
|
||||||
|
def __init__(self, scale_factor=0.18215):
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
class SDXL(LatentFormat):
|
||||||
|
def __init__(self):
|
||||||
|
self.scale_factor = 0.13025
|
||||||
|
|
@ -6,9 +6,11 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
class BaseModel(torch.nn.Module):
|
class BaseModel(torch.nn.Module):
|
||||||
def __init__(self, unet_config, v_prediction=False):
|
def __init__(self, model_config, v_prediction=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
unet_config = model_config.unet_config
|
||||||
|
self.latent_format = model_config.latent_format
|
||||||
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||||
self.diffusion_model = UNetModel(**unet_config)
|
self.diffusion_model = UNetModel(**unet_config)
|
||||||
self.v_prediction = v_prediction
|
self.v_prediction = v_prediction
|
||||||
@ -75,9 +77,16 @@ class BaseModel(torch.nn.Module):
|
|||||||
del to_load
|
del to_load
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def process_latent_in(self, latent):
|
||||||
|
return self.latent_format.process_in(latent)
|
||||||
|
|
||||||
|
def process_latent_out(self, latent):
|
||||||
|
return self.latent_format.process_out(latent)
|
||||||
|
|
||||||
|
|
||||||
class SD21UNCLIP(BaseModel):
|
class SD21UNCLIP(BaseModel):
|
||||||
def __init__(self, unet_config, noise_aug_config, v_prediction=True):
|
def __init__(self, model_config, noise_aug_config, v_prediction=True):
|
||||||
super().__init__(unet_config, v_prediction)
|
super().__init__(model_config, v_prediction)
|
||||||
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
|
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
@ -112,13 +121,13 @@ class SD21UNCLIP(BaseModel):
|
|||||||
return adm_out
|
return adm_out
|
||||||
|
|
||||||
class SDInpaint(BaseModel):
|
class SDInpaint(BaseModel):
|
||||||
def __init__(self, unet_config, v_prediction=False):
|
def __init__(self, model_config, v_prediction=False):
|
||||||
super().__init__(unet_config, v_prediction)
|
super().__init__(model_config, v_prediction)
|
||||||
self.concat_keys = ("mask", "masked_image")
|
self.concat_keys = ("mask", "masked_image")
|
||||||
|
|
||||||
class SDXLRefiner(BaseModel):
|
class SDXLRefiner(BaseModel):
|
||||||
def __init__(self, unet_config, v_prediction=False):
|
def __init__(self, model_config, v_prediction=False):
|
||||||
super().__init__(unet_config, v_prediction)
|
super().__init__(model_config, v_prediction)
|
||||||
self.embedder = Timestep(256)
|
self.embedder = Timestep(256)
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
@ -144,8 +153,8 @@ class SDXLRefiner(BaseModel):
|
|||||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||||
|
|
||||||
class SDXL(BaseModel):
|
class SDXL(BaseModel):
|
||||||
def __init__(self, unet_config, v_prediction=False):
|
def __init__(self, model_config, v_prediction=False):
|
||||||
super().__init__(unet_config, v_prediction)
|
super().__init__(model_config, v_prediction)
|
||||||
self.embedder = Timestep(256)
|
self.embedder = Timestep(256)
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
|
@ -586,6 +586,9 @@ class KSampler:
|
|||||||
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
|
||||||
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
|
||||||
|
|
||||||
|
if latent_image is not None:
|
||||||
|
latent_image = self.model.process_latent_in(latent_image)
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
||||||
|
|
||||||
cond_concat = None
|
cond_concat = None
|
||||||
@ -672,4 +675,4 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||||
|
|
||||||
return samples.to(torch.float32)
|
return self.model.process_latent_out(samples.to(torch.float32))
|
||||||
|
32
comfy/sd.py
32
comfy/sd.py
@ -536,7 +536,7 @@ class CLIP:
|
|||||||
|
|
||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
|
def __init__(self, ckpt_path=None, device=None, config=None):
|
||||||
if config is None:
|
if config is None:
|
||||||
#default SD1.x/SD2.x VAE parameters
|
#default SD1.x/SD2.x VAE parameters
|
||||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||||
@ -550,7 +550,6 @@ class VAE:
|
|||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
self.first_stage_model.load_state_dict(sd, strict=False)
|
self.first_stage_model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
self.scale_factor = scale_factor
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
@ -561,7 +560,7 @@ class VAE:
|
|||||||
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = utils.ProgressBar(steps)
|
pbar = utils.ProgressBar(steps)
|
||||||
|
|
||||||
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
|
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
|
||||||
output = torch.clamp((
|
output = torch.clamp((
|
||||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
|
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
|
||||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
|
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
|
||||||
@ -575,7 +574,7 @@ class VAE:
|
|||||||
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = utils.ProgressBar(steps)
|
pbar = utils.ProgressBar(steps)
|
||||||
|
|
||||||
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
|
encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
|
||||||
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
@ -593,7 +592,7 @@ class VAE:
|
|||||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
|
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
samples = samples_in[x:x+batch_number].to(self.device)
|
samples = samples_in[x:x+batch_number].to(self.device)
|
||||||
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
@ -620,7 +619,7 @@ class VAE:
|
|||||||
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
|
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
|
||||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
@ -958,6 +957,7 @@ def load_gligen(ckpt_path):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
|
||||||
|
#TODO: this function is a mess and should be removed eventually
|
||||||
if config is None:
|
if config is None:
|
||||||
with open(config_path, 'r') as stream:
|
with open(config_path, 'r') as stream:
|
||||||
config = yaml.safe_load(stream)
|
config = yaml.safe_load(stream)
|
||||||
@ -992,12 +992,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = utils.load_torch_file(ckpt_path)
|
state_dict = utils.load_torch_file(ckpt_path)
|
||||||
|
|
||||||
|
class EmptyClass:
|
||||||
|
pass
|
||||||
|
|
||||||
|
model_config = EmptyClass()
|
||||||
|
model_config.unet_config = unet_config
|
||||||
|
from . import latent_formats
|
||||||
|
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||||
|
|
||||||
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
if config['model']["target"].endswith("LatentInpaintDiffusion"):
|
||||||
model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
|
model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
|
||||||
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||||
model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
|
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
|
||||||
else:
|
else:
|
||||||
model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
|
model = model_base.BaseModel(model_config, v_prediction=v_prediction)
|
||||||
|
|
||||||
if fp16:
|
if fp16:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@ -1006,14 +1014,12 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
vae = VAE(scale_factor=scale_factor, config=vae_config)
|
vae = VAE(config=vae_config)
|
||||||
w.first_stage_model = vae.first_stage_model
|
w.first_stage_model = vae.first_stage_model
|
||||||
load_model_weights(w, state_dict)
|
load_model_weights(w, state_dict)
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
class EmptyClass:
|
|
||||||
pass
|
|
||||||
clip_target = EmptyClass()
|
clip_target = EmptyClass()
|
||||||
clip_target.params = clip_config.get("params", {})
|
clip_target.params = clip_config.get("params", {})
|
||||||
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
||||||
@ -1055,7 +1061,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
model.load_model_weights(sd, "model.diffusion_model.")
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae = VAE(scale_factor=model_config.vae_scale_factor)
|
vae = VAE()
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
w.first_stage_model = vae.first_stage_model
|
w.first_stage_model = vae.first_stage_model
|
||||||
load_model_weights(w, sd)
|
load_model_weights(w, sd)
|
||||||
|
@ -7,6 +7,7 @@ from . import sd2_clip
|
|||||||
from . import sdxl_clip
|
from . import sdxl_clip
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
|
from . import latent_formats
|
||||||
|
|
||||||
class SD15(supported_models_base.BASE):
|
class SD15(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@ -21,7 +22,7 @@ class SD15(supported_models_base.BASE):
|
|||||||
"num_head_channels": -1,
|
"num_head_channels": -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
vae_scale_factor = 0.18215
|
latent_format = latent_formats.SD15
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
k = list(state_dict.keys())
|
k = list(state_dict.keys())
|
||||||
@ -48,7 +49,7 @@ class SD20(supported_models_base.BASE):
|
|||||||
"adm_in_channels": None,
|
"adm_in_channels": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
vae_scale_factor = 0.18215
|
latent_format = latent_formats.SD15
|
||||||
|
|
||||||
def v_prediction(self, state_dict):
|
def v_prediction(self, state_dict):
|
||||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||||
@ -97,10 +98,10 @@ class SDXLRefiner(supported_models_base.BASE):
|
|||||||
"transformer_depth": [0, 4, 4, 0],
|
"transformer_depth": [0, 4, 4, 0],
|
||||||
}
|
}
|
||||||
|
|
||||||
vae_scale_factor = 0.13025
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
def get_model(self, state_dict):
|
def get_model(self, state_dict):
|
||||||
return model_base.SDXLRefiner(self.unet_config)
|
return model_base.SDXLRefiner(self)
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
keys_to_replace = {}
|
keys_to_replace = {}
|
||||||
@ -124,10 +125,10 @@ class SDXL(supported_models_base.BASE):
|
|||||||
"adm_in_channels": 2816
|
"adm_in_channels": 2816
|
||||||
}
|
}
|
||||||
|
|
||||||
vae_scale_factor = 0.13025
|
latent_format = latent_formats.SDXL
|
||||||
|
|
||||||
def get_model(self, state_dict):
|
def get_model(self, state_dict):
|
||||||
return model_base.SDXL(self.unet_config)
|
return model_base.SDXL(self)
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
keys_to_replace = {}
|
keys_to_replace = {}
|
||||||
|
@ -49,16 +49,17 @@ class BASE:
|
|||||||
|
|
||||||
def __init__(self, unet_config):
|
def __init__(self, unet_config):
|
||||||
self.unet_config = unet_config
|
self.unet_config = unet_config
|
||||||
|
self.latent_format = self.latent_format()
|
||||||
for x in self.unet_extra_config:
|
for x in self.unet_extra_config:
|
||||||
self.unet_config[x] = self.unet_extra_config[x]
|
self.unet_config[x] = self.unet_extra_config[x]
|
||||||
|
|
||||||
def get_model(self, state_dict):
|
def get_model(self, state_dict):
|
||||||
if self.inpaint_model():
|
if self.inpaint_model():
|
||||||
return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
|
return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
|
||||||
elif self.noise_aug_config is not None:
|
elif self.noise_aug_config is not None:
|
||||||
return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
|
return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
|
||||||
else:
|
else:
|
||||||
return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
|
return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
|
||||||
|
|
||||||
def process_clip_state_dict(self, state_dict):
|
def process_clip_state_dict(self, state_dict):
|
||||||
return state_dict
|
return state_dict
|
||||||
|
6
nodes.py
6
nodes.py
@ -284,6 +284,7 @@ class SaveLatent:
|
|||||||
|
|
||||||
output = {}
|
output = {}
|
||||||
output["latent_tensor"] = samples["samples"]
|
output["latent_tensor"] = samples["samples"]
|
||||||
|
output["latent_format_version_0"] = torch.tensor([])
|
||||||
|
|
||||||
safetensors.torch.save_file(output, file, metadata=metadata)
|
safetensors.torch.save_file(output, file, metadata=metadata)
|
||||||
|
|
||||||
@ -305,7 +306,10 @@ class LoadLatent:
|
|||||||
def load(self, latent):
|
def load(self, latent):
|
||||||
latent_path = folder_paths.get_annotated_filepath(latent)
|
latent_path = folder_paths.get_annotated_filepath(latent)
|
||||||
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
latent = safetensors.torch.load_file(latent_path, device="cpu")
|
||||||
samples = {"samples": latent["latent_tensor"].float()}
|
multiplier = 1.0
|
||||||
|
if "latent_format_version_0" not in latent:
|
||||||
|
multiplier = 1.0 / 0.18215
|
||||||
|
samples = {"samples": latent["latent_tensor"].float() * multiplier}
|
||||||
return (samples, )
|
return (samples, )
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user