mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Auto detect out_channels from model.
This commit is contained in:
parent
79f73a4b33
commit
a47f609f90
@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
|||||||
unet_config = {
|
unet_config = {
|
||||||
"use_checkpoint": False,
|
"use_checkpoint": False,
|
||||||
"image_size": 32,
|
"image_size": 32,
|
||||||
"out_channels": 4,
|
|
||||||
"use_spatial_transformer": True,
|
"use_spatial_transformer": True,
|
||||||
"legacy": False
|
"legacy": False
|
||||||
}
|
}
|
||||||
@ -49,6 +48,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
|||||||
unet_config["dtype"] = dtype
|
unet_config["dtype"] = dtype
|
||||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||||
|
out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0]
|
||||||
|
|
||||||
num_res_blocks = []
|
num_res_blocks = []
|
||||||
channel_mult = []
|
channel_mult = []
|
||||||
@ -122,6 +122,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
|||||||
transformer_depth_middle = -1
|
transformer_depth_middle = -1
|
||||||
|
|
||||||
unet_config["in_channels"] = in_channels
|
unet_config["in_channels"] = in_channels
|
||||||
|
unet_config["out_channels"] = out_channels
|
||||||
unet_config["model_channels"] = model_channels
|
unet_config["model_channels"] = model_channels
|
||||||
unet_config["num_res_blocks"] = num_res_blocks
|
unet_config["num_res_blocks"] = num_res_blocks
|
||||||
unet_config["transformer_depth"] = transformer_depth
|
unet_config["transformer_depth"] = transformer_depth
|
||||||
|
Loading…
Reference in New Issue
Block a user