mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Basic SD3 controlnet implementation.
Still missing the node to properly use it.
This commit is contained in:
parent
66aaa14001
commit
f8f7568d03
91
comfy/cldm/mmdit.py
Normal file
91
comfy/cldm/mmdit.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Dict, Optional
|
||||||
|
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||||
|
import comfy.latent_formats
|
||||||
|
|
||||||
|
class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_blocks = None,
|
||||||
|
dtype = None,
|
||||||
|
device = None,
|
||||||
|
operations = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
|
||||||
|
# controlnet_blocks
|
||||||
|
self.controlnet_blocks = torch.nn.ModuleList([])
|
||||||
|
for _ in range(len(self.joint_blocks)):
|
||||||
|
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
|
||||||
|
None,
|
||||||
|
self.patch_size,
|
||||||
|
self.in_channels,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
strict_img_size=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations
|
||||||
|
)
|
||||||
|
|
||||||
|
self.latent_format = comfy.latent_formats.SD3()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
context: Optional[torch.Tensor] = None,
|
||||||
|
hint = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
#weird sd3 controlnet specific stuff
|
||||||
|
hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint)
|
||||||
|
y = torch.zeros_like(y)
|
||||||
|
|
||||||
|
|
||||||
|
if self.context_processor is not None:
|
||||||
|
context = self.context_processor(context)
|
||||||
|
|
||||||
|
hw = x.shape[-2:]
|
||||||
|
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
||||||
|
x += self.pos_embed_input(hint)
|
||||||
|
|
||||||
|
c = self.t_embedder(timesteps, dtype=x.dtype)
|
||||||
|
if y is not None and self.y_embedder is not None:
|
||||||
|
y = self.y_embedder(y)
|
||||||
|
c = c + y
|
||||||
|
|
||||||
|
if context is not None:
|
||||||
|
context = self.context_embedder(context)
|
||||||
|
|
||||||
|
if self.register_length > 0:
|
||||||
|
context = torch.cat(
|
||||||
|
(
|
||||||
|
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||||
|
default(context, torch.Tensor([]).type_as(x)),
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = []
|
||||||
|
|
||||||
|
blocks = len(self.joint_blocks)
|
||||||
|
for i in range(blocks):
|
||||||
|
context, x = self.joint_blocks[i](
|
||||||
|
context,
|
||||||
|
x,
|
||||||
|
c=c,
|
||||||
|
use_checkpoint=self.use_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.controlnet_blocks[i](x)
|
||||||
|
count = self.depth // blocks
|
||||||
|
if i == blocks - 1:
|
||||||
|
count -= 1
|
||||||
|
for j in range(count):
|
||||||
|
output.append(out)
|
||||||
|
|
||||||
|
return {"output": output}
|
@ -11,6 +11,7 @@ import comfy.ops
|
|||||||
import comfy.cldm.cldm
|
import comfy.cldm.cldm
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
import comfy.ldm.cascade.controlnet
|
import comfy.ldm.cascade.controlnet
|
||||||
|
import comfy.cldm.mmdit
|
||||||
|
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
@ -94,13 +95,17 @@ class ControlBase:
|
|||||||
|
|
||||||
for key in control:
|
for key in control:
|
||||||
control_output = control[key]
|
control_output = control[key]
|
||||||
|
applied_to = set()
|
||||||
for i in range(len(control_output)):
|
for i in range(len(control_output)):
|
||||||
x = control_output[i]
|
x = control_output[i]
|
||||||
if x is not None:
|
if x is not None:
|
||||||
if self.global_average_pooling:
|
if self.global_average_pooling:
|
||||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
x *= self.strength
|
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
||||||
|
applied_to.add(x)
|
||||||
|
x *= self.strength
|
||||||
|
|
||||||
if x.dtype != output_dtype:
|
if x.dtype != output_dtype:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
@ -120,17 +125,18 @@ class ControlBase:
|
|||||||
if o[i].shape[0] < prev_val.shape[0]:
|
if o[i].shape[0] < prev_val.shape[0]:
|
||||||
o[i] = prev_val + o[i]
|
o[i] = prev_val + o[i]
|
||||||
else:
|
else:
|
||||||
o[i] += prev_val
|
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
if control_model is not None:
|
if control_model is not None:
|
||||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||||
|
|
||||||
|
self.compression_ratio = compression_ratio
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
self.model_sampling_current = None
|
self.model_sampling_current = None
|
||||||
self.manual_cast_dtype = manual_cast_dtype
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
@ -308,6 +314,37 @@ class ControlLora(ControlNet):
|
|||||||
def inference_memory_requirements(self, dtype):
|
def inference_memory_requirements(self, dtype):
|
||||||
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
||||||
|
|
||||||
|
def load_controlnet_mmdit(sd):
|
||||||
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
|
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
|
||||||
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||||
|
for k in sd:
|
||||||
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
|
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||||
|
|
||||||
|
controlnet_config = model_config.unet_config
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
if manual_cast_dtype is not None:
|
||||||
|
operations = comfy.ops.manual_cast
|
||||||
|
else:
|
||||||
|
operations = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
||||||
|
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
||||||
|
|
||||||
|
if len(missing) > 0:
|
||||||
|
logging.warning("missing controlnet keys: {}".format(missing))
|
||||||
|
|
||||||
|
if len(unexpected) > 0:
|
||||||
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||||
|
|
||||||
|
control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None):
|
def load_controlnet(ckpt_path, model=None):
|
||||||
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
if "lora_controlnet" in controlnet_data:
|
if "lora_controlnet" in controlnet_data:
|
||||||
@ -360,6 +397,8 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
|
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
|
||||||
|
return load_controlnet_mmdit(controlnet_data)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
|
@ -745,6 +745,8 @@ class MMDiT(nn.Module):
|
|||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
context_processor_layers = None,
|
context_processor_layers = None,
|
||||||
context_size = 4096,
|
context_size = 4096,
|
||||||
|
num_blocks = None,
|
||||||
|
final_layer = True,
|
||||||
dtype = None, #TODO
|
dtype = None, #TODO
|
||||||
device = None,
|
device = None,
|
||||||
operations = None,
|
operations = None,
|
||||||
@ -766,7 +768,10 @@ class MMDiT(nn.Module):
|
|||||||
# apply magic --> this defines a head_size of 64
|
# apply magic --> this defines a head_size of 64
|
||||||
self.hidden_size = 64 * depth
|
self.hidden_size = 64 * depth
|
||||||
num_heads = depth
|
num_heads = depth
|
||||||
|
if num_blocks is None:
|
||||||
|
num_blocks = depth
|
||||||
|
|
||||||
|
self.depth = depth
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
|
||||||
self.x_embedder = PatchEmbed(
|
self.x_embedder = PatchEmbed(
|
||||||
@ -821,7 +826,7 @@ class MMDiT(nn.Module):
|
|||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
attn_mode=attn_mode,
|
attn_mode=attn_mode,
|
||||||
pre_only=i == depth - 1,
|
pre_only=(i == num_blocks - 1) and final_layer,
|
||||||
rmsnorm=rmsnorm,
|
rmsnorm=rmsnorm,
|
||||||
scale_mod_only=scale_mod_only,
|
scale_mod_only=scale_mod_only,
|
||||||
swiglu=swiglu,
|
swiglu=swiglu,
|
||||||
@ -830,11 +835,12 @@ class MMDiT(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
)
|
)
|
||||||
for i in range(depth)
|
for i in range(num_blocks)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
if final_layer:
|
||||||
|
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
if compile_core:
|
if compile_core:
|
||||||
assert False
|
assert False
|
||||||
@ -893,6 +899,7 @@ class MMDiT(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
c_mod: torch.Tensor,
|
c_mod: torch.Tensor,
|
||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
|
control = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.register_length > 0:
|
if self.register_length > 0:
|
||||||
context = torch.cat(
|
context = torch.cat(
|
||||||
@ -905,13 +912,20 @@ class MMDiT(nn.Module):
|
|||||||
|
|
||||||
# context is B, L', D
|
# context is B, L', D
|
||||||
# x is B, L, D
|
# x is B, L, D
|
||||||
for block in self.joint_blocks:
|
blocks = len(self.joint_blocks)
|
||||||
context, x = block(
|
for i in range(blocks):
|
||||||
|
context, x = self.joint_blocks[i](
|
||||||
context,
|
context,
|
||||||
x,
|
x,
|
||||||
c=c_mod,
|
c=c_mod,
|
||||||
use_checkpoint=self.use_checkpoint,
|
use_checkpoint=self.use_checkpoint,
|
||||||
)
|
)
|
||||||
|
if control is not None:
|
||||||
|
control_o = control.get("output")
|
||||||
|
if i < len(control_o):
|
||||||
|
add = control_o[i]
|
||||||
|
if add is not None:
|
||||||
|
x += add
|
||||||
|
|
||||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return x
|
return x
|
||||||
@ -922,6 +936,7 @@ class MMDiT(nn.Module):
|
|||||||
t: torch.Tensor,
|
t: torch.Tensor,
|
||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
|
control = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass of DiT.
|
Forward pass of DiT.
|
||||||
@ -943,7 +958,7 @@ class MMDiT(nn.Module):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
context = self.context_embedder(context)
|
context = self.context_embedder(context)
|
||||||
|
|
||||||
x = self.forward_core_with_concat(x, c, context)
|
x = self.forward_core_with_concat(x, c, context, control)
|
||||||
|
|
||||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||||
return x[:,:,:hw[-2],:hw[-1]]
|
return x[:,:,:hw[-2],:hw[-1]]
|
||||||
@ -956,7 +971,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
|
|||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
context: Optional[torch.Tensor] = None,
|
context: Optional[torch.Tensor] = None,
|
||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
|
control = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return super().forward(x, timesteps, context=context, y=y)
|
return super().forward(x, timesteps, context=context, y=y, control=control)
|
||||||
|
|
||||||
|
@ -41,7 +41,9 @@ def detect_unet_config(state_dict, key_prefix):
|
|||||||
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
||||||
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
|
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
|
||||||
unet_config["patch_size"] = patch_size
|
unet_config["patch_size"] = patch_size
|
||||||
unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
|
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
|
||||||
|
if final_layer in state_dict:
|
||||||
|
unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
|
||||||
|
|
||||||
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
|
unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
|
||||||
unet_config["input_size"] = None
|
unet_config["input_size"] = None
|
||||||
@ -435,10 +437,11 @@ def model_config_from_diffusers_unet(state_dict):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||||
if depth > 0:
|
if num_blocks > 0:
|
||||||
|
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
|
||||||
out_sd = {}
|
out_sd = {}
|
||||||
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
|
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
|
||||||
for k in sd_map:
|
for k in sd_map:
|
||||||
weight = state_dict.get(k, None)
|
weight = state_dict.get(k, None)
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
|
@ -298,7 +298,8 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
|||||||
key_map = {}
|
key_map = {}
|
||||||
|
|
||||||
depth = mmdit_config.get("depth", 0)
|
depth = mmdit_config.get("depth", 0)
|
||||||
for i in range(depth):
|
num_blocks = mmdit_config.get("num_blocks", depth)
|
||||||
|
for i in range(num_blocks):
|
||||||
block_from = "transformer_blocks.{}".format(i)
|
block_from = "transformer_blocks.{}".format(i)
|
||||||
block_to = "{}joint_blocks.{}".format(output_prefix, i)
|
block_to = "{}joint_blocks.{}".format(output_prefix, i)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user