diff --git a/comfy/controlnet.py b/comfy/controlnet.py index c0f9b651..86089196 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -449,7 +449,9 @@ def load_controlnet_flux_instantx(sd): if union_cnet in new_sd: num_union_modes = new_sd[union_cnet].shape[0] - control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4 + + control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, new_sd) latent_format = comfy.latent_formats.Flux() diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index d8b77612..c033dea5 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -52,7 +52,7 @@ class MistolineControlnetBlock(nn.Module): class ControlNetFlux(Flux): - def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs): super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) self.main_model_double = 19 @@ -80,7 +80,12 @@ class ControlNetFlux(Flux): self.gradient_checkpointing = False self.latent_input = latent_input - self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + if control_latent_channels is None: + control_latent_channels = self.in_channels + else: + control_latent_channels *= 2 * 2 #patch size + + self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device) if not self.latent_input: if self.mistoline: self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)