diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8d4fb7b6..8fca6d8c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -46,6 +46,18 @@ def detect_unet_config(state_dict, key_prefix): unet_config['c_cond'] = 2048 elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: unet_config['stable_cascade_stage'] = 'b' + w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] + if w.shape[-1] == 640: + unet_config['c_hidden'] = [320, 640, 1280, 1280] + unet_config['nhead'] = [-1, -1, 20, 20] + unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] + elif w.shape[-1] == 576: #stage b lite + unet_config['c_hidden'] = [320, 576, 1152, 1152] + unet_config['nhead'] = [-1, 9, 18, 18] + unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] + return unet_config unet_config = {