SD2.x controlnets now work.

This commit is contained in:
comfyanonymous 2023-03-08 01:13:38 -05:00
parent 5ea597710f
commit c70f0ac64b

View File

@ -489,21 +489,39 @@ def load_controlnet(ckpt_path, model=None):
if model_management.should_use_fp16() and controlnet_data[key].dtype == torch.float16: if model_management.should_use_fp16() and controlnet_data[key].dtype == torch.float16:
use_fp16 = True use_fp16 = True
control_model = cldm.ControlNet(image_size=32, if context_dim == 768:
in_channels=4, #SD1.x
hint_channels=3, control_model = cldm.ControlNet(image_size=32,
model_channels=320, in_channels=4,
attention_resolutions=[ 4, 2, 1 ], hint_channels=3,
num_res_blocks=2, model_channels=320,
channel_mult=[ 1, 2, 4, 4 ], attention_resolutions=[ 4, 2, 1 ],
num_heads=8, num_res_blocks=2,
use_spatial_transformer=True, channel_mult=[ 1, 2, 4, 4 ],
transformer_depth=1, num_heads=8,
context_dim=context_dim, use_spatial_transformer=True,
use_checkpoint=True, transformer_depth=1,
legacy=False, context_dim=context_dim,
use_fp16=use_fp16) use_checkpoint=True,
legacy=False,
use_fp16=use_fp16)
else:
#SD2.x
control_model = cldm.ControlNet(image_size=32,
in_channels=4,
hint_channels=3,
model_channels=320,
attention_resolutions=[ 4, 2, 1 ],
num_res_blocks=2,
channel_mult=[ 1, 2, 4, 4 ],
num_head_channels=64,
use_spatial_transformer=True,
use_linear_in_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
legacy=False,
use_fp16=use_fp16)
if pth: if pth:
if 'difference' in controlnet_data: if 'difference' in controlnet_data:
if model is not None: if model is not None: