Support VAEs in diffusers format.

This commit is contained in:
comfyanonymous 2023-05-28 02:02:09 -04:00
parent 0fc483dcfd
commit a532888846

View File

@ -14,6 +14,7 @@ from .t2i_adapter import adapter
from . import utils from . import utils
from . import clip_vision from . import clip_vision
from . import gligen from . import gligen
from . import diffusers_convert
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
@ -504,10 +505,16 @@ class VAE:
if config is None: if config is None:
#default SD1.x/SD2.x VAE parameters #default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path) self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
else: else:
self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path) self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval() self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = utils.load_torch_file(ckpt_path)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.first_stage_model.load_state_dict(sd, strict=False)
self.scale_factor = scale_factor self.scale_factor = scale_factor
if device is None: if device is None:
device = model_management.get_torch_device() device = model_management.get_torch_device()