From 4248b1618ffd0878ae2502fd4633e30bcbd7554b Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 5 Jun 2025 07:07:17 -0700
Subject: [PATCH] Let chroma TE work on regular flux. (#8429)

---
 comfy/ldm/flux/controlnet.py | 5 ++++-
 comfy/ldm/flux/model.py      | 6 +++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py
index 5322c4891..dbd2a47c0 100644
--- a/comfy/ldm/flux/controlnet.py
+++ b/comfy/ldm/flux/controlnet.py
@@ -121,6 +121,9 @@ class ControlNetFlux(Flux):
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
 
+        if y is None:
+            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+
         # running on sequences img
         img = self.img_in(img)
 
@@ -174,7 +177,7 @@ class ControlNetFlux(Flux):
             out["output"] = out_output[:self.main_model_single]
         return out
 
-    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+    def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
         patch_size = 2
         if self.latent_input:
             hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index ef4ba4106..53f27e3a7 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -101,6 +101,10 @@ class Flux(nn.Module):
         transformer_options={},
         attn_mask: Tensor = None,
     ) -> Tensor:
+
+        if y is None:
+            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -188,7 +192,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
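
What the patch does, in isolation: both Flux forward paths previously required the pooled
conditioning vector `y` (the CLIP pooled output), which the Chroma text encoder setup does not
supply. The change makes `y` optional and substitutes an all-zero vector of width
`params.vec_in_dim` when it is missing. Below is a minimal, self-contained sketch of that
fallback pattern; the helper name `pooled_or_zeros` and the example width 768 are illustrative
assumptions, not names taken from the ComfyUI code.

import torch

def pooled_or_zeros(y, img, vec_in_dim):
    # Mirror of the patched behavior: when no pooled text-encoder output is
    # provided, fall back to a zero vector of the expected width so the
    # downstream vector-conditioning layer still receives a valid input.
    # Matching img's device and dtype keeps the fallback consistent with
    # the rest of the forward pass.
    if y is None:
        y = torch.zeros((img.shape[0], vec_in_dim), device=img.device, dtype=img.dtype)
    return y

# img stands in for the packed latent sequence (batch, tokens, channels).
img = torch.randn(2, 16, 64)
y = pooled_or_zeros(None, img, vec_in_dim=768)  # 768 is an assumed width
assert y.shape == (2, 768) and (y == 0).all()

A zero vector is a neutral choice here: the model simply receives no pooled-text signal, while
every tensor shape downstream stays exactly as it would be with a real CLIP pooled output.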