mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Auto detect out_channels from model.
This commit is contained in:
parent
79f73a4b33
commit
a47f609f90
@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
"out_channels": 4,
|
||||
"use_spatial_transformer": True,
|
||||
"legacy": False
|
||||
}
|
||||
@ -49,6 +48,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0]
|
||||
|
||||
num_res_blocks = []
|
||||
channel_mult = []
|
||||
@ -122,6 +122,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
transformer_depth_middle = -1
|
||||
|
||||
unet_config["in_channels"] = in_channels
|
||||
unet_config["out_channels"] = out_channels
|
||||
unet_config["model_channels"] = model_channels
|
||||
unet_config["num_res_blocks"] = num_res_blocks
|
||||
unet_config["transformer_depth"] = transformer_depth
|
||||
|
Loading…
Reference in New Issue
Block a user