diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index d4d32b87..4a58c823 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -92,7 +92,7 @@ class ControlNet(nn.Module): transformer_depth_middle=None, transformer_depth_output=None, attn_precision=None, - union_controlnet=False, + union_controlnet_num_control_type=None, device=None, operations=comfy.ops.disable_weight_init, **kwargs, @@ -320,8 +320,8 @@ class ControlNet(nn.Module): self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self._feature_size += ch - if union_controlnet: - self.num_control_type = 6 + if union_controlnet_num_control_type is not None: + self.num_control_type = union_controlnet_num_control_type num_trans_channel = 320 num_trans_head = 8 num_trans_layer = 1 diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 84286f1f..b8e27c71 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -414,7 +414,7 @@ def load_controlnet(ckpt_path, model=None): new_sd[diffusers_keys[k]] = controlnet_data.pop(k) if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet - controlnet_config["union_controlnet"] = True + controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0] for k in list(controlnet_data.keys()): new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') new_sd[new_k] = controlnet_data.pop(k)