diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index d9d990a7..41619758 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None):
return ControlLora(controlnet_data)
controlnet_config = None
+ supported_inference_dtypes = None
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
- unet_dtype = comfy.model_management.unet_dtype()
- controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None):
return net
if controlnet_config is None:
- unet_dtype = comfy.model_management.unet_dtype()
- controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
+ supported_inference_dtypes = model_config.supported_inference_dtypes
+ controlnet_config = model_config.unet_config
+
load_device = comfy.model_management.get_torch_device()
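+    # Prefer the dtypes the detected model config declares it supports (e.g. Stage C
+    # declares bf16/fp32); otherwise fall back to the global default selection.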
+ if supported_inference_dtypes is None:
+ unet_dtype = comfy.model_management.unet_dtype()
+ else:
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
+ controlnet_config["dtype"] = unet_dtype
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 2252a075..8ba76737 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -37,3 +37,9 @@ class SDXL(LatentFormat):
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
+
+class SC_Prior(LatentFormat):
+ def __init__(self):
+ self.scale_factor = 1.0
+
+
diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py
new file mode 100644
index 00000000..3c5bf997
--- /dev/null
+++ b/comfy/ldm/cascade/common.py
@@ -0,0 +1,161 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+import torch.nn as nn
+from comfy.ldm.modules.attention import optimized_attention
+
+class Linear(torch.nn.Linear):
+ def reset_parameters(self):
+ return None
+
+class Conv2d(torch.nn.Conv2d):
+ def reset_parameters(self):
+ return None
+
+class OptimizedAttention(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.heads = nhead
+
+ self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+ self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+ self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+ def forward(self, q, k, v):
+ q = self.to_q(q)
+ k = self.to_k(k)
+ v = self.to_v(v)
+
+ out = optimized_attention(q, k, v, self.heads)
+
+ return self.out_proj(out)
+
+class Attention2D(nn.Module):
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
+ # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
+
+ def forward(self, x, kv, self_attn=False):
+ orig_shape = x.shape
+        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # BxCxHxW -> Bx(HxW)xC
+ if self_attn:
+ kv = torch.cat([x, kv], dim=1)
+ # x = self.attn(x, kv, kv, need_weights=False)[0]
+ x = self.attn(x, kv, kv)
+ x = x.permute(0, 2, 1).view(*orig_shape)
+ return x
+
+
+def LayerNorm2d_op(operations):
+ class LayerNorm2d(operations.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return LayerNorm2d
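+# LayerNorm2d_op is a factory: the returned class inherits from the supplied operations'
+# LayerNorm (e.g. the weight-init-disabled or manual-cast variants) and adds the
+# NCHW <-> NHWC permutes needed to normalize over the channel dimension.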
+
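+# GlobalResponseNorm expects NHWC input (callers permute first): Gx is the per-channel
+# spatial L2 norm, and Nx rescales it by the mean norm across channels.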
+class GlobalResponseNorm(nn.Module):
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
+ def __init__(self, dim, dtype=None, device=None):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
+ super().__init__()
+ self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.channelwise = nn.Sequential(
+ operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
+ nn.Dropout(dropout),
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x, x_skip=None):
+ x_res = x
+ x = self.norm(self.depthwise(x))
+ if x_skip is not None:
+ x = torch.cat([x, x_skip], dim=1)
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + x_res
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
+ self.kv_mapper = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(c_cond, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x, kv):
+ kv = self.kv_mapper(kv)
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+ return x
+
+
+class FeedForwardBlock(nn.Module):
+ def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.channelwise = nn.Sequential(
+ operations.Linear(c, c * 4, dtype=dtype, device=device),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
+ nn.Dropout(dropout),
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
+ )
+
+ def forward(self, x):
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x
+
+
+class TimestepBlock(nn.Module):
+ def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
+ super().__init__()
+ self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
+ self.conds = conds
+ for cname in conds:
+ setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
+
+ def forward(self, x, t):
+ t = t.chunk(len(self.conds) + 1, dim=1)
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
+ for i, c in enumerate(self.conds):
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
+ a, b = a + ac, b + bc
+ return x * (1 + a) + b
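+
+# Hypothetical usage sketch (not part of this module): t is the concatenation of the base
+# r embedding plus one embedding per extra cond, which forward() chunks apart, e.g.
+#   block = TimestepBlock(c=2048, c_timestep=64, conds=['sca', 'crp'], operations=ops)
+#   out = block(x, torch.cat([r_embed, sca_embed, crp_embed], dim=1))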
diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py
new file mode 100644
index 00000000..2e0f4706
--- /dev/null
+++ b/comfy/ldm/cascade/stage_c.py
@@ -0,0 +1,271 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Stability AI
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+import torch
+from torch import nn
+import numpy as np
+import math
+from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
+# from .controlnet import ControlNetDeliverer
+
+class UpDownBlock2d(nn.Module):
+ def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ assert mode in ['up', 'down']
+ interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
+ align_corners=True) if enabled else nn.Identity()
+ mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+
+class StageC(nn.Module):
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
+ dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
+ dtype=None, device=None, operations=None):
+ super().__init__()
+ self.dtype = dtype
+ self.c_r = c_r
+ self.t_conds = t_conds
+ self.c_clip_seq = c_clip_seq
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+ if not isinstance(self_attn, list):
+ self_attn = [self_attn] * len(c_hidden)
+
+ # CONDITIONING
+ self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
+ self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
+ )
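+        # With the default patch_size=1 the PixelUnshuffle is a no-op; larger patch sizes
+        # fold patch_size x patch_size pixel blocks into channels before the 1x1 projection.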
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
+ if block_type == 'C':
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'A':
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'F':
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
+ elif block_type == 'T':
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
+ else:
+ raise Exception(f'Block type {block_type} not supported')
+
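+        # Each level_config string (e.g. 'CTA') lists the block types repeated at that level:
+        # C = ResBlock, T = TimestepBlock, A = AttnBlock, F = FeedForwardBlock.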
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ self.down_downscalers = nn.ModuleList()
+ self.down_repeat_mappers = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ if i > 0:
+ self.down_downscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
+ ))
+ else:
+ self.down_downscalers.append(nn.Identity())
+ down_block = nn.ModuleList()
+ for _ in range(blocks[0][i]):
+ for block_type in level_config[i]:
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
+ down_block.append(block)
+ self.down_blocks.append(down_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[0][i] - 1):
+                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.down_repeat_mappers.append(block_repeat_mappers)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ self.up_upscalers = nn.ModuleList()
+ self.up_repeat_mappers = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ if i > 0:
+ self.up_upscalers.append(nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
+ ))
+ else:
+ self.up_upscalers.append(nn.Identity())
+ up_block = nn.ModuleList()
+ for j in range(blocks[1][::-1][i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
+ self_attn=self_attn[i])
+ up_block.append(block)
+ self.up_blocks.append(up_block)
+ if block_repeat is not None:
+ block_repeat_mappers = nn.ModuleList()
+ for _ in range(block_repeat[1][::-1][i] - 1):
+                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
+ self.up_repeat_mappers.append(block_repeat_mappers)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6),
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ # self.apply(self._init_weights) # General init
+ # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
+ # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
+ #
+ # # blocks
+ # for level_block in self.down_blocks + self.up_blocks:
+ # for block in level_block:
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
+ # elif isinstance(block, TimestepBlock):
+ # for layer in block.modules():
+ # if isinstance(layer, nn.Linear):
+ # nn.init.constant_(layer.weight, 0)
+ #
+ # def _init_weights(self, m):
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
+ # torch.nn.init.xavier_uniform_(m.weight)
+ # if m.bias is not None:
+ # nn.init.constant_(m.bias, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
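+        # Sinusoidal embedding of the continuous timestep r (roughly in [0, 1], see
+        # StableCascadeSampling.timestep), scaled by max_positions; returns shape [B, c_r].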
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
+ return emb
+
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
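+        # The pooled text and image embeddings are each expanded into c_clip_seq tokens of
+        # width c_cond, concatenated with the per-token text embedding, then normalized.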
+ clip_txt = self.clip_txt_mapper(clip_txt)
+ if len(clip_txt_pooled.shape) == 2:
+ clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
+ if len(clip_img.shape) == 2:
+ clip_img = clip_img.unsqueeze(1)
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
+ clip = self.clip_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, clip, cnet=None):
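+        # The isinstance checks below also unwrap FSDP-wrapped blocks via _fsdp_wrapped_module.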
+ level_outputs = []
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
+ for down_block, downscaler, repmap in block_group:
+ x = downscaler(x)
+ for i in range(len(repmap) + 1):
+ for block in down_block:
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if i < len(repmap):
+ x = repmap[i](x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+ x = level_outputs[0]
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
+ for j in range(len(repmap) + 1):
+ for k, block in enumerate(up_block):
+ if isinstance(block, ResBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ ResBlock)):
+ skip = level_outputs[i] if k == 0 and i > 0 else None
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
+ align_corners=True)
+ if cnet is not None:
+ next_cnet = cnet()
+ if next_cnet is not None:
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
+ align_corners=True)
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ AttnBlock)):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock) or (
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
+ TimestepBlock)):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ if j < len(repmap):
+ x = repmap[j](x)
+ x = upscaler(x)
+ return x
+
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
+ for c in self.t_conds:
+ t_cond = kwargs.get(c, torch.zeros_like(r))
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
+
+ # Model Blocks
+ x = self.embedding(x)
+ if cnet is not None:
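+            # NOTE: ControlNetDeliverer's import is commented out at the top of this file,
+            # so the cnet path assumes the cascade controlnet module is wired in separately.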
+ cnet = ControlNetDeliverer(cnet)
+ level_outputs = self._down_encode(x, r_embed, clip, cnet)
+ x = self._up_decode(level_outputs, r_embed, clip, cnet)
+ return self.clf(x)
+
+ def update_weights_ema(self, src_model, beta=0.999):
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index aafb88e0..dde2a2f3 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1,5 +1,6 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
+from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management
@@ -12,9 +13,10 @@ class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
+ STABLE_CASCADE = 4
-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
def model_sampling(model_config, model_type):
@@ -27,6 +29,9 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
+ elif model_type == ModelType.STABLE_CASCADE:
+ c = EPS
+ s = StableCascadeSampling
class ModelSampling(s, c):
pass
@@ -35,7 +40,7 @@ def model_sampling(model_config, model_type):
class BaseModel(torch.nn.Module):
- def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
unet_config = model_config.unet_config
@@ -48,7 +53,7 @@ class BaseModel(torch.nn.Module):
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
- self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
+ self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@@ -427,3 +432,32 @@ class SD_X4Upscaler(BaseModel):
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
return out
+
+class StableCascade_C(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=StageC)
+ self.diffusion_model.eval().requires_grad_(False)
+
+ def extra_conds(self, **kwargs):
+ out = {}
+ clip_text_pooled = kwargs["pooled_output"]
+ if clip_text_pooled is not None:
+ out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
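+        # Optional image conditioning from CLIP vision (unCLIP); a zero embedding is used
+        # when no unclip conditioning is provided, so clip_img is always populated.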
+ if "unclip_conditioning" in kwargs:
+ embeds = []
+ for unclip_cond in kwargs["unclip_conditioning"]:
+ weight = unclip_cond["strength"]
+ embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
+ clip_img = torch.cat(embeds, dim=1)
+ else:
+ clip_img = torch.zeros((1, 1, 768))
+ out["clip_img"] = comfy.conds.CONDRegular(clip_img)
+ out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+ out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
+ return out
+
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index ea824c44..8d4fb7b6 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -28,9 +28,26 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
-def detect_unet_config(state_dict, key_prefix, dtype):
+def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
+ if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
+ unet_config = {}
+ text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
+ if text_mapper_name in state_dict_keys:
+ unet_config['stable_cascade_stage'] = 'c'
+ w = state_dict[text_mapper_name]
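+            # clip_txt_mapper's output width distinguishes the lite (1536) from the full (2048) Stage C.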
+ if w.shape[0] == 1536: #stage c lite
+ unet_config['c_cond'] = 1536
+ unet_config['c_hidden'] = [1536, 1536]
+ unet_config['nhead'] = [24, 24]
+ unet_config['blocks'] = [[4, 12], [12, 4]]
+ elif w.shape[0] == 2048: #stage c full
+ unet_config['c_cond'] = 2048
+ elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
+ unet_config['stable_cascade_stage'] = 'b'
+ return unet_config
+
unet_config = {
"use_checkpoint": False,
"image_size": 32,
@@ -45,7 +62,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
else:
unet_config["adm_in_channels"] = None
- unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -159,8 +175,8 @@ def model_config_from_unet_config(unet_config):
print("no match", unet_config)
return None
-def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
- unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
+ unet_config = detect_unet_config(state_dict, unet_key_prefix)
model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
@@ -206,7 +222,7 @@ def convert_config(unet_config):
return new_config
-def unet_config_from_diffusers_unet(state_dict, dtype):
+def unet_config_from_diffusers_unet(state_dict, dtype=None):
match = {}
transformer_depth = []
@@ -313,8 +329,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
return convert_config(unet_config)
return None
-def model_config_from_diffusers_unet(state_dict, dtype):
- unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
+def model_config_from_diffusers_unet(state_dict):
+ unet_config = unet_config_from_diffusers_unet(state_dict)
if unet_config is not None:
return model_config_from_unet_config(unet_config)
return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 0b3f6ead..eb7178b4 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
-def unet_dtype(device=None, model_params=0):
+def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0):
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
- return torch.float16
+ if torch.float16 in supported_dtypes:
+ return torch.float16
+ if should_use_bf16(device):
+ if torch.bfloat16 in supported_dtypes:
+ return torch.bfloat16
return torch.float32
# None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device):
+def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
return None
- fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+ fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
- if fp16_supported:
+ bf16_supported = should_use_bf16(inference_device)
+ if bf16_supported and weight_dtype == torch.bfloat16:
+ return None
+
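+    # The weights cannot be used as-is on this device: pick a manual-cast compute dtype,
+    # preferring fp16 over bf16 when both are supported.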
+ if fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
+
+ elif bf16_supported and torch.bfloat16 in supported_dtypes:
+ return torch.bfloat16
else:
return torch.float32
@@ -760,6 +771,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return True
+def should_use_bf16(device=None):
+ if is_intel_xpu():
+ return True
+
+    if device is None:
+        device = torch.device("cuda")
+
+    if device.type != "cuda": #only CUDA capability is probed below; avoid crashing on cpu/mps
+        return False
+
+    props = torch.cuda.get_device_properties(device)
+ if props.major >= 8:
+ return True
+
+ return False
+
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index d5870027..b1fbf3e2 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -132,3 +132,33 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
+
+class StableCascadeSampling(ModelSamplingDiscrete):
+ def __init__(self, model_config=None):
+ super().__init__()
+ self.num_timesteps = 1000
+        cosine_s = 8e-3
+ self.cosine_s = torch.tensor([cosine_s])
+ sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
+ self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
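+        # Precompute a discrete sigma table from the continuous cosine schedule so the
+        # inherited ModelSamplingDiscrete machinery can be reused unchanged.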
+ for x in range(self.num_timesteps):
+ t = x / self.num_timesteps
+ sigmas[x] = self.sigma(t)
+
+ self.set_sigmas(sigmas)
+
+ def sigma(self, timestep):
+ alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999)
+ return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
+
+ def timestep(self, sigma):
+ return super().timestep(sigma) / 1000.0
+
+ def percent_to_sigma(self, percent):
+ if percent <= 0.0:
+ return 999999999.9
+ if percent >= 1.0:
+ return 0.0
+
+ percent = 1.0 - percent
+ return self.sigma(torch.tensor(percent))
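+
+# Illustrative sanity check (not part of this file): with cosine_s = 8e-3 the clamped
+# schedule gives sigma(0) = sqrt(0.0001/0.9999) ~= 0.01 at the cleanest step and
+# sigma(1) = sqrt(0.9999/0.0001) ~= 100 at the noisiest.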
diff --git a/comfy/sd.py b/comfy/sd.py
index 5e37cff9..f3ec62b3 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -450,15 +450,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
- unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
class WeightsLoader(torch.nn.Module):
pass
- model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
- model_config.set_manual_cast(manual_cast_dtype)
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
+    if model_config is None:
+        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
-    if model_config is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -507,16 +507,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
-    unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
- manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
- if "input_blocks.0.0.weight" in sd: #ldm
- model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
+ if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
+ model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None
new_sd = sd
else: #diffusers
- model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
+ model_config = model_detection.model_config_from_diffusers_unet(sd)
if model_config is None:
return None
@@ -528,8 +527,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
new_sd[diffusers_keys[k]] = sd.pop(k)
else:
print(diffusers_keys[k], k)
+
offload_device = model_management.unet_offload_device()
- model_config.set_manual_cast(manual_cast_dtype)
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+ model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1d442d4d..a8863e72 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -306,5 +306,38 @@ class SD_X4Upscaler(SD20):
out = model_base.SD_X4Upscaler(self, device=device)
return out
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
+class Stable_Cascade_C(supported_models_base.BASE):
+ unet_config = {
+ "stable_cascade_stage": 'c',
+ }
+
+ unet_extra_config = {}
+
+ latent_format = latent_formats.SC_Prior
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ def process_unet_state_dict(self, state_dict):
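+        # The reference Stage C weights use nn.MultiheadAttention with fused in_proj_{weight,bias};
+        # split each fused tensor into the separate to_q/to_k/to_v layout of OptimizedAttention.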
+ key_list = list(state_dict.keys())
+ for y in ["weight", "bias"]:
+ suffix = "in_proj_{}".format(y)
+ keys = filter(lambda a: a.endswith(suffix), key_list)
+ for k_from in keys:
+ weights = state_dict.pop(k_from)
+ prefix = k_from[:-(len(suffix) + 1)]
+ shape_from = weights.shape[0] // 3
+                p = ["to_q", "to_k", "to_v"]
+                for x in range(3):
+                    k_to = "{}.{}.{}".format(prefix, p[x], y)
+                    state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+ return state_dict
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.StableCascade_C(self, device=device)
+ return out
+
+ def clip_target(self):
+ return None
+
+
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C]
models += [SVD_img2vid]
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 58535a9f..3bd4f9c6 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -22,13 +22,14 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
vae_key_prefix = ["first_stage_model."]
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
- if s.unet_config[k] != unet_config[k]:
+ if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
return True
@@ -80,5 +81,6 @@ class BASE:
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
- def set_manual_cast(self, manual_cast_dtype):
+ def set_inference_dtype(self, dtype, manual_cast_dtype):
+ self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype