From 03e83bb5d052be16aaad7599ffeb1eade9481508 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 6 Mar 2024 02:25:42 -0500
Subject: [PATCH] Support stable cascade canny controlnet.

---
 comfy/controlnet.py             | 15 ++++--
 comfy/ldm/cascade/controlnet.py | 93 +++++++++++++++++++++++++++++++++
 comfy/ldm/cascade/stage_c.py    | 17 +++---
 3 files changed, 115 insertions(+), 10 deletions(-)
 create mode 100644 comfy/ldm/cascade/controlnet.py

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 602b2d3b..39344a4d 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -9,6 +9,7 @@
 import comfy.ops
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
+import comfy.ldm.cascade.controlnet
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -78,6 +79,7 @@ class ControlBase:
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range
         c.global_average_pooling = self.global_average_pooling
+        c.compression_ratio = self.compression_ratio
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -433,11 +435,12 @@ def load_controlnet(ckpt_path, model=None):
     return control
 
 class T2IAdapter(ControlBase):
-    def __init__(self, t2i_model, channels_in, device=None):
+    def __init__(self, t2i_model, channels_in, compression_ratio, device=None):
         super().__init__(device)
         self.t2i_model = t2i_model
         self.channels_in = channels_in
         self.control_input = None
+        self.compression_ratio = compression_ratio
 
     def scale_image_to(self, width, height):
         unshuffle_amount = self.t2i_model.unshuffle_amount
@@ -482,11 +485,13 @@ class T2IAdapter(ControlBase):
         return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
 
     def copy(self):
-        c = T2IAdapter(self.t2i_model, self.channels_in)
+        c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio)
         self.copy_to(c)
         return c
 
 def load_t2i_adapter(t2i_data):
+    compression_ratio = 8
+
     if 'adapter' in t2i_data:
         t2i_data = t2i_data['adapter']
     if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
@@ -514,8 +519,12 @@ def load_t2i_adapter(t2i_data):
         if cin == 256 or cin == 768:
             xl = True
         model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
+    elif "backbone.0.0.weight" in keys:
+        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
+        compression_ratio = 32
     else:
         return None
+
     missing, unexpected = model_ad.load_state_dict(t2i_data)
     if len(missing) > 0:
         print("t2i missing", missing)
@@ -523,4 +532,4 @@ def load_t2i_adapter(t2i_data):
     if len(unexpected) > 0:
         print("t2i unexpected", unexpected)
 
-    return T2IAdapter(model_ad, model_ad.input_channels)
+    return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio)
diff --git a/comfy/ldm/cascade/controlnet.py b/comfy/ldm/cascade/controlnet.py
new file mode 100644
index 00000000..fdb1ba4d
--- /dev/null
+++ b/comfy/ldm/cascade/controlnet.py
@@ -0,0 +1,93 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Stability AI
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+import torchvision
+from torch import nn
+from .common import LayerNorm2d_op
+
+
+class CNetResBlock(nn.Module):
+    def __init__(self, c, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.blocks = nn.Sequential(
+            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
+            nn.GELU(),
+            operations.Conv2d(c, c, kernel_size=3, padding=1),
+            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
+            nn.GELU(),
+            operations.Conv2d(c, c, kernel_size=3, padding=1),
+        )
+
+    def forward(self, x):
+        return x + self.blocks(x)
+
+
+class ControlNet(nn.Module):
+    def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
+        super().__init__()
+        if bottleneck_mode is None:
+            bottleneck_mode = 'effnet'
+        self.proj_blocks = proj_blocks
+        if bottleneck_mode == 'effnet':
+            embd_channels = 1280
+            self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
+            if c_in != 3:
+                in_weights = self.backbone[0][0].weight.data
+                self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
+                if c_in > 3:
+                    # nn.init.constant_(self.backbone[0][0].weight, 0)
+                    self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
+                else:
+                    self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
+        elif bottleneck_mode == 'simple':
+            embd_channels = c_in
+            self.backbone = nn.Sequential(
+                operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
+                nn.LeakyReLU(0.2, inplace=True),
+                operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
+            )
+        elif bottleneck_mode == 'large':
+            self.backbone = nn.Sequential(
+                operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
+                nn.LeakyReLU(0.2, inplace=True),
+                operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
+                *[CNetResBlock(1024) for _ in range(8)],
+                operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
+            )
+            embd_channels = 1280
+        else:
+            raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
+        self.projections = nn.ModuleList()
+        for _ in range(len(proj_blocks)):
+            self.projections.append(nn.Sequential(
+                operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
+                nn.LeakyReLU(0.2, inplace=True),
+                operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
+            ))
+            # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
+        self.xl = False
+        self.input_channels = c_in
+        self.unshuffle_amount = 8
+
+    def forward(self, x):
+        x = self.backbone(x)
+        proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
+        for i, idx in enumerate(self.proj_blocks):
+            proj_outputs[idx] = self.projections[i](x)
+        return proj_outputs
diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py
index 08e33ade..67c1e52b 100644
--- a/comfy/ldm/cascade/stage_c.py
+++ b/comfy/ldm/cascade/stage_c.py
@@ -194,10 +194,10 @@ class StageC(nn.Module):
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   ResBlock)):
                         if cnet is not None:
-                            next_cnet = cnet()
+                            next_cnet = cnet.pop()
                             if next_cnet is not None:
                                 x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
-                                                                  align_corners=True)
+                                                                  align_corners=True).to(x.dtype)
                         x = block(x)
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
@@ -228,10 +228,10 @@ class StageC(nn.Module):
                             x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                 align_corners=True)
                         if cnet is not None:
-                            next_cnet = cnet()
+                            next_cnet = cnet.pop()
                             if next_cnet is not None:
                                 x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
-                                                                  align_corners=True)
+                                                                  align_corners=True).to(x.dtype)
                         x = block(x, skip)
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
@@ -248,7 +248,7 @@ class StageC(nn.Module):
                 x = upscaler(x)
         return x
 
-    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
+    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
         # Process the conditioning embeddings
         r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
         for c in self.t_conds:
@@ -256,10 +256,13 @@ class StageC(nn.Module):
             r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
         clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
 
+        if control is not None:
+            cnet = control.get("input")
+        else:
+            cnet = None
+
         # Model Blocks
         x = self.embedding(x)
-        if cnet is not None:
-            cnet = ControlNetDeliverer(cnet)
         level_outputs = self._down_encode(x, r_embed, clip, cnet)
         x = self._up_decode(level_outputs, r_embed, clip, cnet)
         return self.clf(x)
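-- 
Editor's note (not part of the commit): a minimal sketch of how the new code
path in load_t2i_adapter is exercised. State dicts containing a
"backbone.0.0.weight" key are routed to comfy.ldm.cascade.controlnet.ControlNet,
and the returned T2IAdapter carries compression_ratio = 32: the control hint is
scaled to 32x the Stage C latent resolution, and the EfficientNetV2-S backbone
(overall stride 32) brings it back down to latent size. The checkpoint path
below is hypothetical; comfy.utils.load_torch_file is the checkpoint loader
used elsewhere in the codebase.

    import comfy.utils
    import comfy.controlnet

    # Hypothetical path to a Stable Cascade canny controlnet checkpoint.
    sd = comfy.utils.load_torch_file("models/controlnet/cascade_canny.safetensors")

    # Routes to the new elif branch when "backbone.0.0.weight" is in the keys.
    control = comfy.controlnet.load_t2i_adapter(sd)
    assert control is not None
    assert control.compression_ratio == 32  # regular T2I adapters keep the default of 8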