Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-04-19 19:03:51 +00:00

Commit de142eaad5: Simpler base model code.
Parent commit: 4b0b516544
@@ -4,7 +4,7 @@ import yaml
 
 import folder_paths
 from comfy.ldm.util import instantiate_from_config
-from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
+from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
 import os.path as osp
 import re
 import torch
@@ -84,28 +84,4 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb
     # Put together new checkpoint
     sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
 
-    clip = None
-    vae = None
-
-    class WeightsLoader(torch.nn.Module):
-        pass
-
-    w = WeightsLoader()
-    load_state_dict_to = []
-    if output_vae:
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
-        w.first_stage_model = vae.first_stage_model
-        load_state_dict_to = [w]
-
-    if output_clip:
-        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
-        w.cond_stage_model = clip.cond_stage_model
-        load_state_dict_to = [w]
-
-    model = instantiate_from_config(config["model"])
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
-
-    if fp16:
-        model = model.half()
-
-    return ModelPatcher(model), clip, vae
+    return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
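Note: load_diffusers still assembles the merged state dict from the converted UNet, VAE and text-encoder weights, but everything after that is now delegated to load_checkpoint. A minimal sketch of the call site, assuming a placeholder model path and that load_checkpoint still returns the same (ModelPatcher, CLIP, VAE) triple the removed code built by hand:

# Illustrative only; "/models/diffusers/my_model" is a placeholder path.
model_patcher, clip, vae = load_diffusers("/models/diffusers/my_model",
                                          fp16=True,
                                          output_vae=True,
                                          output_clip=True,
                                          embedding_directory=None)
# Internally this now ends with:
#   return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)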
comfy/model_base.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+import torch
+from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
+from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
+import numpy as np
+
+class BaseModel(torch.nn.Module):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__()
+
+        self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+        self.diffusion_model = UNetModel(**unet_config)
+        self.v_prediction = v_prediction
+        if self.v_prediction:
+            self.parameterization = "v"
+        else:
+            self.parameterization = "eps"
+        if "adm_in_channels" in unet_config:
+            self.adm_channels = unet_config["adm_in_channels"]
+        else:
+            self.adm_channels = 0
+        print("v_prediction", v_prediction)
+        print("adm", self.adm_channels)
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if given_betas is not None:
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+
+        self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
+        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
+        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+
+    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
+        if c_concat is not None:
+            xc = torch.cat([x] + c_concat, dim=1)
+        else:
+            xc = x
+        context = torch.cat(c_crossattn, 1)
+        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
+
+    def get_dtype(self):
+        return self.diffusion_model.dtype
+
+    def is_adm(self):
+        return self.adm_channels > 0
+
+class SD21UNCLIP(BaseModel):
+    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
+        super().__init__(unet_config, v_prediction)
+        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
+
+class SDInpaint(BaseModel):
+    def __init__(self, unet_config, v_prediction=False):
+        super().__init__(unet_config, v_prediction)
+        self.concat_keys = ("mask", "masked_image")
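For context, register_schedule above only precomputes the DDPM noise schedule and keeps it as float32 buffers on the module. A self-contained sketch of the same arithmetic, assuming make_beta_schedule("linear", ...) follows the usual LDM convention of a linspace in sqrt-beta space:

import numpy as np
import torch

timesteps = 1000
linear_start, linear_end = 0.00085, 0.012

# Assumed "linear" schedule: linspace between sqrt(linear_start) and sqrt(linear_end), squared.
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=np.float64) ** 2
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

# These three arrays are what BaseModel registers as buffers.
betas_buf = torch.tensor(betas, dtype=torch.float32)
alphas_cumprod_buf = torch.tensor(alphas_cumprod, dtype=torch.float32)
alphas_cumprod_prev_buf = torch.tensor(alphas_cumprod_prev, dtype=torch.float32)
print(betas_buf.shape, float(alphas_cumprod_buf[-1]))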
@@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
 
             c['transformer_options'] = transformer_options
 
-            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
+            output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
             del input_x
 
             model_management.throw_exception_if_processing_interrupted()
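The cond=c to **c change relies on the conditioning dict using keys that line up with keyword parameters of the model function (the new BaseModel.apply_model takes c_concat, c_crossattn, c_adm, control and transformer_options), so the dict can be splatted straight into the call. A toy illustration with a stub standing in for apply_model; the tensor shapes are made up:

import torch

def apply_model(x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
    # Stub: just report which conditioning actually arrived.
    print("crossattn:", None if c_crossattn is None else c_crossattn[0].shape,
          "| adm:", None if c_adm is None else c_adm.shape)
    return x

c = {"c_crossattn": [torch.randn(2, 77, 768)], "transformer_options": {}}
x = torch.randn(2, 4, 64, 64)
t = torch.tensor([999, 999])
apply_model(x, t, **c)   # new style: dict keys become keyword arguments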
@@ -460,9 +460,11 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
             uncond[temp[1]] = [o[0], n]
 
 
-def encode_adm(noise_augmentor, conds, batch_size, device):
+def encode_adm(conds, batch_size, device, noise_augmentor=None):
     for t in range(len(conds)):
         x = conds[t]
+        adm_out = None
+        if noise_augmentor is not None:
             if 'adm' in x[1]:
                 adm_inputs = []
                 weights = []
@@ -488,6 +490,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
                 adm_out = torch.cat((c_adm, noise_level_emb), 1)
             else:
                 adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
+        else:
+            if 'adm' in x[1]:
+                adm_out = x[1]["adm"].to(device)
+        if adm_out is not None:
             x[1] = x[1].copy()
             x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
 
@@ -591,14 +597,17 @@ class KSampler:
             apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
             apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
-        if self.model.model.diffusion_model.dtype == torch.float16:
+        if self.model.get_dtype() == torch.float16:
             precision_scope = torch.autocast
         else:
             precision_scope = contextlib.nullcontext
 
+        if self.model.is_adm():
+            noise_augmentor = None
             if hasattr(self.model, 'noise_augmentor'): #unclip
-            positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
-            negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
+                noise_augmentor = self.model.noise_augmentor
+            positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
+            negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
 
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
 
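With the guard moved to self.model.is_adm(), encode_adm now also runs for ADM models that have no noise augmentor; in that case it just moves the raw 'adm' conditioning to the device and repeats it across the batch. A standalone sketch of that branch (encode_adm_plain, the shapes, and the 2816-wide tensor are illustrative, not part of the codebase):

import torch

def encode_adm_plain(conds, batch_size, device):
    # Mirrors the new noise_augmentor-is-None path of encode_adm.
    for t in range(len(conds)):
        x = conds[t]
        adm_out = None
        if 'adm' in x[1]:
            adm_out = x[1]["adm"].to(device)
        if adm_out is not None:
            x[1] = x[1].copy()
            x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
    return conds

conds = [[torch.randn(1, 77, 768), {"adm": torch.randn(1, 2816)}]]
conds = encode_adm_plain(conds, batch_size=2, device="cpu")
print(conds[0][1]["adm_encoded"].shape)  # torch.Size([2, 2816])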
comfy/sd.py (68 changed lines)
@@ -15,8 +15,15 @@ from . import utils
 from . import clip_vision
 from . import gligen
 from . import diffusers_convert
+from . import model_base
 
 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
+    replace_prefix = {"model.diffusion_model.": "diffusion_model."}
+    for rp in replace_prefix:
+        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
+        for x in replace:
+            sd[x[1]] = sd.pop(x[0])
+
     m, u = model.load_state_dict(sd, strict=False)
 
     k = list(sd.keys())
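The new prefix remap exists because the refactored model classes expose the UNet directly as diffusion_model, while checkpoints still store its keys under model.diffusion_model.; load_model_weights renames those keys in place before calling load_state_dict. The same logic applied to a toy dict:

sd = {
    "model.diffusion_model.input_blocks.0.0.weight": "unet tensor",
    "first_stage_model.decoder.conv_out.weight": "vae tensor",
}
replace_prefix = {"model.diffusion_model.": "diffusion_model."}
for rp in replace_prefix:
    replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])),
                       filter(lambda a: a.startswith(rp), sd.keys())))
    for x in replace:
        sd[x[1]] = sd.pop(x[0])

print(sorted(sd))
# ['diffusion_model.input_blocks.0.0.weight', 'first_stage_model.decoder.conv_out.weight']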
@@ -182,7 +189,7 @@ def model_lora_keys(model, key_map={}):
 
     counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
+        tk = "diffusion_model.input_blocks.{}.1".format(b)
         up_counter = 0
         for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
@@ -193,13 +200,13 @@ def model_lora_keys(model, key_map={}):
         if up_counter >= 4:
             counter += 1
     for c in LORA_UNET_MAP_ATTENTIONS:
-        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
+        k = "diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
             key_map[lora_key] = k
     counter = 3
     for b in range(12):
-        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
+        tk = "diffusion_model.output_blocks.{}.1".format(b)
         up_counter = 0
         for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
@@ -223,7 +230,7 @@ def model_lora_keys(model, key_map={}):
     ds_counter = 0
     counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.input_blocks.{}.0".format(b)
+        tk = "diffusion_model.input_blocks.{}.0".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -242,7 +249,7 @@ def model_lora_keys(model, key_map={}):
 
     counter = 0
     for b in range(3):
-        tk = "model.diffusion_model.middle_block.{}".format(b)
+        tk = "diffusion_model.middle_block.{}".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -256,7 +263,7 @@ def model_lora_keys(model, key_map={}):
     counter = 0
     us_counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.output_blocks.{}.0".format(b)
+        tk = "diffusion_model.output_blocks.{}.0".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -332,7 +339,7 @@ class ModelPatcher:
                         patch_list[i] = patch_list[i].to(device)
 
     def model_dtype(self):
-        return self.model.diffusion_model.dtype
+        return self.model.get_dtype()
 
     def add_patches(self, patches, strength=1.0):
         p = {}
@@ -764,7 +771,7 @@ def load_controlnet(ckpt_path, model=None):
         for x in controlnet_data:
             c_m = "control_model."
             if x.startswith(c_m):
-                sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
+                sd_key = "diffusion_model.{}".format(x[len(c_m):])
                 if sd_key in model_sd:
                     cd = controlnet_data[x]
                     cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
@@ -931,7 +938,8 @@ def load_gligen(ckpt_path):
     model = model.half()
     return model
 
-def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
+def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    if config is None:
         with open(config_path, 'r') as stream:
             config = yaml.safe_load(stream)
     model_config_params = config['model']['params']
@@ -942,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
     fp16 = False
     if "unet_config" in model_config_params:
         if "params" in model_config_params["unet_config"]:
-            if "use_fp16" in model_config_params["unet_config"]["params"]:
-                fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
+            unet_config = model_config_params["unet_config"]["params"]
+            if "use_fp16" in unet_config:
+                fp16 = unet_config["use_fp16"]
+
+    noise_aug_config = None
+    if "noise_aug_config" in model_config_params:
+        noise_aug_config = model_config_params["noise_aug_config"]
+
+    v_prediction = False
+
+    if "parameterization" in model_config_params:
+        if model_config_params["parameterization"] == "v":
+            v_prediction = True
 
     clip = None
     vae = None
@@ -963,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]
 
-    model = instantiate_from_config(config["model"])
-    sd = utils.load_torch_file(ckpt_path)
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+    if config['model']["target"].endswith("LatentInpaintDiffusion"):
+        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+    elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
+        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+    else:
+        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+
+    if state_dict is None:
+        state_dict = utils.load_torch_file(ckpt_path)
+    model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
 
     if fp16:
         model = model.half()
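With every argument now defaulted, load_checkpoint can be driven either by a yaml config on disk or by an already-parsed config plus an in-memory state dict, which is how load_diffusers calls it. A hedged sketch of both call patterns; the paths are placeholders, and the three-value unpacking assumes the return mirrors what load_diffusers forwards:

from comfy.sd import load_checkpoint

# 1) Path-based, as before (placeholder paths).
model, clip, vae = load_checkpoint(config_path="configs/v1-inference.yaml",
                                   ckpt_path="models/checkpoints/model.ckpt",
                                   embedding_directory="models/embeddings")

# 2) In-memory, as used by load_diffusers: config already parsed into a dict,
#    weights already merged into a single state dict.
# model, clip, vae = load_checkpoint(config=config, state_dict=sd,
#                                    embedding_directory="models/embeddings")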
@@ -1073,16 +1099,20 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
     model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
 
+    unclip_model = False
+    inpaint_model = False
     if noise_aug_config is not None: #SD2.x unclip model
         sd_config["noise_aug_config"] = noise_aug_config
         sd_config["image_size"] = 96
         sd_config["embedding_dropout"] = 0.25
         sd_config["conditioning_key"] = 'crossattn-adm'
+        unclip_model = True
         model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
     elif unet_config["in_channels"] > 4: #inpainting model
         sd_config["conditioning_key"] = "hybrid"
         sd_config["finetune_keys"] = None
         model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        inpaint_model = True
     else:
         sd_config["conditioning_key"] = "crossattn"
 
@@ -1096,13 +1126,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         unet_config["num_classes"] = "sequential"
         unet_config["adm_in_channels"] = sd[unclip].shape[1]
 
+    v_prediction = False
     if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
         k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
         out = sd[k]
         if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
+            v_prediction = True
             sd_config["parameterization"] = 'v'
 
-    model = instantiate_from_config(model_config)
+    if inpaint_model:
+        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+    elif unclip_model:
+        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+    else:
+        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
 
     if fp16: