diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 0c8cd30c..7b202b7a 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -148,7 +148,7 @@ class ControlBase:
                     elif self.strength_type == StrengthType.LINEAR_UP:
                         x *= (self.strength ** float(len(control_output) - i))

-                    if x.dtype != output_dtype:
+                    if output_dtype is not None and x.dtype != output_dtype:
                         x = x.to(output_dtype)

                 out[key].append(x)
@@ -206,7 +206,6 @@ class ControlNet(ControlBase):
         if self.manual_cast_dtype is not None:
             dtype = self.manual_cast_dtype

-        output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -236,7 +235,7 @@ class ControlNet(ControlBase):
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
-        return self.control_merge(control, control_prev, output_dtype)
+        return self.control_merge(control, control_prev, output_dtype=None)

     def copy(self):
         c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py
index 0e160b07..2c658a4b 100644
--- a/comfy/ldm/flux/controlnet.py
+++ b/comfy/ldm/flux/controlnet.py
@@ -23,8 +23,12 @@ class ControlNetFlux(Flux):
         self.controlnet_blocks = nn.ModuleList([])
         for _ in range(self.params.depth):
             controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
-            # controlnet_block = zero_module(controlnet_block)
             self.controlnet_blocks.append(controlnet_block)
+
+        self.controlnet_single_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth_single_blocks):
+            self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
+
         self.gradient_checkpointing = False
         self.latent_input = latent_input
         self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@@ -78,26 +82,39 @@ class ControlNetFlux(Flux):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)

-        block_res_samples = ()
+        controlnet_double = ()

-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
-            block_res_samples = block_res_samples + (img,)
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)

-        controlnet_block_res_samples = ()
-        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
-            block_res_sample = controlnet_block(block_res_sample)
-            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+        img = torch.cat((txt, img), 1)
+        controlnet_single = ()

-        repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
+        for i in range(len(self.single_blocks)):
+            img = self.single_blocks[i](img, vec=vec, pe=pe)
+            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
+
+        repeat = math.ceil(self.main_model_double / len(controlnet_double))
         if self.latent_input:
             out_input = ()
-            for x in controlnet_block_res_samples:
+            for x in controlnet_double:
                 out_input += (x,) * repeat
         else:
-            out_input = (controlnet_block_res_samples * repeat)
-        return {"input": out_input[:self.main_model_double]}
+            out_input = (controlnet_double * repeat)
+
+        out = {"input": out_input[:self.main_model_double]}
+        if len(controlnet_single) > 0:
+            repeat = math.ceil(self.main_model_single / len(controlnet_single))
+            out_output = ()
+            if self.latent_input:
+                for x in controlnet_single:
+                    out_output += (x,) * repeat
+            else:
+                out_output = (controlnet_single * repeat)
+            out["output"] = out_output[:self.main_model_single]
+        return out

     def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
         patch_size = 2
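Note (not from the patch): a minimal standalone sketch of the repeat-and-trim mapping that the forward_orig hunk above uses to spread a controlnet's outputs across the base model's blocks. The helper name spread_controlnet_outputs and the example block count (19, Flux dev's value for main_model_double; main_model_single is 38) are assumptions for illustration only.

import math

# Hypothetical helper mirroring the repeat/ceil/slice logic in forward_orig above;
# the block counts are assumed (Flux dev: 19 double blocks, 38 single blocks).
def spread_controlnet_outputs(outputs, target_blocks, latent_input):
    repeat = math.ceil(target_blocks / len(outputs))
    if latent_input:
        # per-entry repeat: consecutive main-model blocks reuse the same output
        spread = ()
        for x in outputs:
            spread += (x,) * repeat
    else:
        # whole-tuple tiling: the outputs cycle across the main-model blocks
        spread = outputs * repeat
    return spread[:target_blocks]

# Five double-block outputs spread over 19 main-model double blocks:
# latent_input=True  -> ('d0', 'd0', 'd0', 'd0', 'd1', ...)  ceil(19/5)=4 copies each, trimmed to 19
# latent_input=False -> ('d0', 'd1', 'd2', 'd3', 'd4', 'd0', ...)
print(spread_controlnet_outputs(tuple("d%d" % i for i in range(5)), 19, latent_input=True))
print(spread_controlnet_outputs(tuple("d%d" % i for i in range(5)), 19, latent_input=False))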