diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 354b0036..dcfe492c 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -34,6 +34,8 @@ import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
+import comfy.ldm.flux.controlnet_xlabs
+
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -416,6 +418,7 @@ def load_controlnet_mmdit(sd):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control
 
+
 def load_controlnet_hunyuandit(controlnet_data):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
 
@@ -427,6 +430,15 @@ def load_controlnet_hunyuandit(controlnet_data):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
     return control
 
+def load_controlnet_flux_xlabs(sd):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
+    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, sd)
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
@@ -489,7 +501,10 @@ def load_controlnet(ckpt_path, model=None):
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
-        return load_controlnet_mmdit(controlnet_data)
+        if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
+            return load_controlnet_flux_xlabs(controlnet_data)
+        else:
+            return load_controlnet_mmdit(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
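The dispatch added to load_controlnet() above relies purely on state-dict keys: both the SD3 diffusers-format and XLabs Flux controlnets carry "controlnet_blocks.0.weight", so the Flux variant is recognized by the extra "double_blocks.0.img_attn.norm.key_norm.scale" key. A minimal standalone sketch of that detection logic (the key names come from the patch itself; the sample dicts are hypothetical stand-ins for real checkpoints, not part of the patch):

    def detect_controlnet_family(sd):
        # sd: a checkpoint state dict, e.g. comfy.utils.load_torch_file(path, safe_load=True)
        if 'after_proj_list.18.bias' in sd:
            return "hunyuan_dit"
        if "controlnet_blocks.0.weight" in sd:
            # Both SD3 (diffusers) and XLabs Flux checkpoints have controlnet_blocks;
            # only the Flux one also carries Flux-style double-block attention norms.
            if "double_blocks.0.img_attn.norm.key_norm.scale" in sd:
                return "flux_xlabs"
            return "sd3_mmdit"
        return "other"

    # Hypothetical key sets standing in for real checkpoints:
    assert detect_controlnet_family({"controlnet_blocks.0.weight": 0,
                                     "double_blocks.0.img_attn.norm.key_norm.scale": 0}) == "flux_xlabs"
    assert detect_controlnet_family({"controlnet_blocks.0.weight": 0}) == "sd3_mmdit"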
diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet_xlabs.py
new file mode 100644
index 00000000..3f40021b
--- /dev/null
+++ b/comfy/ldm/flux/controlnet_xlabs.py
@@ -0,0 +1,104 @@
+#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+
+import torch
+from torch import Tensor, nn
+from einops import rearrange, repeat
+
+from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                     MLPEmbedder, SingleStreamBlock,
+                     timestep_embedding)
+
+from .model import Flux
+import comfy.ldm.common_dit
+
+
+class ControlNetFlux(Flux):
+    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+
+        # add ControlNet blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth):
+            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+            # controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.gradient_checkpointing = False
+        self.input_hint_block = nn.Sequential(
+            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+        )
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        controlnet_cond: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        y: Tensor,
+        guidance: Tensor = None,
+    ) -> Tensor:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+        controlnet_cond = self.input_hint_block(controlnet_cond)
+        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        controlnet_cond = self.pos_embed_input(controlnet_cond)
+        img = img + controlnet_cond
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        block_res_samples = ()
+
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+            block_res_samples = block_res_samples + (img,)
+
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        return {"output": (controlnet_block_res_samples * 10)[:19]}
+
+    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+        hint = hint * 2.0 - 1.0
+
+        bs, c, h, w = x.shape
+        patch_size = 2
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
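Note the return value of forward_orig(): ControlNetFlux runs only self.params.depth double blocks, far fewer than the 19 double blocks of base Flux, so "(controlnet_block_res_samples * 10)[:19]" tiles the short tuple of residuals cyclically until every Flux double block has an entry. A standalone sketch of that tiling arithmetic, with a controlnet depth of 2 assumed for illustration (the *10 and [:19] constants are from the patch):

    # Assumed sizes: base Flux has 19 double blocks; the controlnet branch is
    # much shallower (depth 2 is an assumption here, not taken from the patch).
    flux_double_blocks = 19
    controlnet_depth = 2

    block_res_samples = tuple(f"res{i}" for i in range(controlnet_depth))
    tiled = (block_res_samples * 10)[:19]   # repeat the short tuple, trim to 19

    assert len(tiled) == flux_double_blocks
    # Flux double block i receives controlnet residual i % controlnet_depth:
    assert all(tiled[i] == f"res{i % controlnet_depth}" for i in range(flux_double_blocks))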
""" - def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): super().__init__() self.dtype = dtype params = FluxParams(**kwargs) @@ -83,7 +83,8 @@ class Flux(nn.Module): ] ) - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + if final_layer: + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) def forward_orig( self, @@ -94,6 +95,7 @@ class Flux(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor = None, + control=None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -112,8 +114,15 @@ class Flux(nn.Module): ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + for i in range(len(self.double_blocks)): + img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe) + + if control is not None: #Controlnet + control_o = control.get("output") + if i < len(control_o): + add = control_o[i] + if add is not None: + img += add img = torch.cat((txt, img), 1) for block in self.single_blocks: @@ -123,7 +132,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, y, guidance, **kwargs): + def forward(self, x, timestep, context, y, guidance, control=None, **kwargs): bs, c, h, w = x.shape patch_size = 2 x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) @@ -138,5 +147,5 @@ class Flux(nn.Module): img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control) return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]