Auto detect out_channels from model.

This commit is contained in:
comfyanonymous 2024-01-02 01:50:57 -05:00
parent 79f73a4b33
commit a47f609f90

View File

@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config = { unet_config = {
"use_checkpoint": False, "use_checkpoint": False,
"image_size": 32, "image_size": 32,
"out_channels": 4,
"use_spatial_transformer": True, "use_spatial_transformer": True,
"legacy": False "legacy": False
} }
@ -49,6 +48,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config["dtype"] = dtype unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0]
num_res_blocks = [] num_res_blocks = []
channel_mult = [] channel_mult = []
@ -122,6 +122,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
transformer_depth_middle = -1 transformer_depth_middle = -1
unet_config["in_channels"] = in_channels unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth unet_config["transformer_depth"] = transformer_depth