mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Stable Cascade Stage C.
This commit is contained in:
parent
5e06baf112
commit
f83109f09b
@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
return ControlLora(controlnet_data)
|
return ControlLora(controlnet_data)
|
||||||
|
|
||||||
controlnet_config = None
|
controlnet_config = None
|
||||||
|
supported_inference_dtypes = None
|
||||||
|
|
||||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
|
||||||
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
|
||||||
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
||||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||||
@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None):
|
|||||||
return net
|
return net
|
||||||
|
|
||||||
if controlnet_config is None:
|
if controlnet_config is None:
|
||||||
unet_dtype = comfy.model_management.unet_dtype()
|
model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
|
||||||
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
supported_inference_dtypes = model_config.supported_inference_dtypes
|
||||||
|
controlnet_config = model_config.unet_config
|
||||||
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
if supported_inference_dtypes is None:
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype()
|
||||||
|
else:
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
||||||
|
|
||||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
if manual_cast_dtype is not None:
|
if manual_cast_dtype is not None:
|
||||||
controlnet_config["operations"] = comfy.ops.manual_cast
|
controlnet_config["operations"] = comfy.ops.manual_cast
|
||||||
|
controlnet_config["dtype"] = unet_dtype
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||||
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
|
@ -37,3 +37,9 @@ class SDXL(LatentFormat):
|
|||||||
class SD_X4(LatentFormat):
|
class SD_X4(LatentFormat):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scale_factor = 0.08333
|
self.scale_factor = 0.08333
|
||||||
|
|
||||||
|
class SC_Prior(LatentFormat):
    """Latent format that sets ``scale_factor`` to 1.0."""

    def __init__(self):
        # NOTE(review): the base-class initialiser is not invoked here; only
        # ``scale_factor`` is set.
        self.scale_factor = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
161
comfy/ldm/cascade/common.py
Normal file
161
comfy/ldm/cascade/common.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Stability AI
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
class Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose ``reset_parameters`` is a no-op.

    NOTE(review): presumably used so that weights are left untouched until a
    checkpoint is loaded -- confirm against callers.
    """

    def reset_parameters(self):
        # Deliberately does nothing and returns None.
        pass
|
||||||
|
|
||||||
|
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` whose ``reset_parameters`` is a no-op.

    NOTE(review): presumably used so that weights are left untouched until a
    checkpoint is loaded -- confirm against callers.
    """

    def reset_parameters(self):
        # Deliberately does nothing and returns None.
        pass
|
||||||
|
|
||||||
|
class OptimizedAttention(nn.Module):
    """Attention layer with separate q/k/v projections and an output projection.

    The attention itself is delegated to ``optimized_attention``.

    Args:
        c: Input and output width of every projection.
        nhead: Head count forwarded to ``optimized_attention``.
        dropout: Accepted for signature compatibility; not used here.
        dtype, device: Forwarded to each ``operations.Linear``.
        operations: Namespace providing the ``Linear`` layer class.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        linear_kwargs = {"bias": True, "dtype": dtype, "device": device}
        self.to_q = operations.Linear(c, c, **linear_kwargs)
        self.to_k = operations.Linear(c, c, **linear_kwargs)
        self.to_v = operations.Linear(c, c, **linear_kwargs)

        self.out_proj = operations.Linear(c, c, **linear_kwargs)

    def forward(self, q, k, v):
        """Project ``q``/``k``/``v``, attend, and project the result."""
        attended = optimized_attention(self.to_q(q), self.to_k(k), self.to_v(v), self.heads)
        return self.out_proj(attended)
|
||||||
|
|
||||||
|
class Attention2D(nn.Module):
    """Run ``OptimizedAttention`` over the flattened spatial positions of a 4D tensor.

    Args:
        c: Channel count of the input feature map.
        nhead: Number of attention heads.
        dropout: Accepted for signature compatibility; not passed to the
            attention layer.
        dtype, device, operations: Forwarded to ``OptimizedAttention``.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        # An earlier variant (kept upstream as a comment) used nn.MultiheadAttention here.
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)

    def forward(self, x, kv, self_attn=False):
        """Attend from ``x`` to ``kv``, optionally also to ``x`` itself.

        Args:
            x: 4D tensor; dims 2 and 3 are flattened into one token axis.
            kv: Context tokens, concatenated after the image tokens along
                dim 1 when ``self_attn`` is true.
            self_attn: Whether the image tokens are prepended to ``kv``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        spatial_shape = x.shape
        # B x C x H x W -> B x (H*W) x C
        tokens = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        context = torch.cat([tokens, kv], dim=1) if self_attn else kv
        attended = self.attn(tokens, context, context)
        return attended.permute(0, 2, 1).view(*spatial_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def LayerNorm2d_op(operations):
    """Build a channels-first LayerNorm class on top of ``operations.LayerNorm``.

    Args:
        operations: Namespace providing the ``LayerNorm`` base class.

    Returns:
        A ``LayerNorm2d`` class. Its ``forward`` moves dim 1 of a 4D input to
        the last position, applies the base layer norm, and moves it back.
    """

    class LayerNorm2d(operations.LayerNorm):
        def forward(self, x):
            channels_last = x.permute(0, 2, 3, 1)
            normalized = super().forward(channels_last)
            return normalized.permute(0, 3, 1, 2)

    return LayerNorm2d
|
||||||
|
|
||||||
|
class GlobalResponseNorm(nn.Module):
    """Global response normalisation over a channels-last tensor.

    from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

    Args:
        dim: Size of the last axis of the input; ``gamma`` and ``beta`` have
            shape ``(1, 1, 1, dim)``.
        dtype, device: Storage dtype/device of ``gamma`` and ``beta``.
    """

    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))

    def forward(self, x):
        """Return ``gamma * (x * Nx) + beta + x`` in the dtype/device of ``x``."""
        # L2 norm over dims 1 and 2, then divided by its mean over the last axis.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        # The parameters may be stored in a different dtype or on a different
        # device than the activations (e.g. when weights are cast manually at
        # run time); align them with the input so the result is not silently
        # promoted or rejected for a device mismatch.
        gamma = self.gamma.to(device=x.device, dtype=x.dtype)
        beta = self.beta.to(device=x.device, dtype=x.dtype)
        return gamma * (x * Nx) + beta + x
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
    """Residual block: depthwise conv, layer norm, then a channel-wise MLP.

    Args:
        c: Channel count of the input and output.
        c_skip: Channel count of the optional skip tensor concatenated before
            the MLP.
        kernel_size: Depthwise convolution kernel size (padding is
            ``kernel_size // 2``).
        dropout: Dropout probability inside the MLP.
        dtype, device: Forwarded to every layer created here.
        operations: Namespace providing ``Conv2d``, ``Linear`` and
            ``LayerNorm``.
    """

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None):  # , num_heads=4, expansion=2):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        hidden = c * 4
        self.depthwise = operations.Conv2d(
            c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, **factory
        )
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, hidden, **factory),
            nn.GELU(),
            GlobalResponseNorm(hidden, **factory),
            nn.Dropout(dropout),
            operations.Linear(hidden, c, **factory),
        )

    def forward(self, x, x_skip=None):
        """Return ``channelwise(norm(depthwise(x)) [+ skip]) + x``.

        ``x_skip``, when given, is concatenated along dim 1 after the norm.
        """
        residual = x
        features = self.norm(self.depthwise(x))
        if x_skip is not None:
            features = torch.cat([features, x_skip], dim=1)
        # The MLP works on the last axis, so channels are moved last and back.
        features = features.permute(0, 2, 3, 1)
        return self.channelwise(features).permute(0, 3, 1, 2) + residual
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock(nn.Module):
    """Residual attention block over a feature map with mapped conditioning.

    Args:
        c: Channel count of the feature map.
        c_cond: Width of the conditioning tokens before ``kv_mapper``.
        nhead: Number of attention heads.
        self_attn: Stored and forwarded to ``Attention2D.forward``.
        dropout: Passed positionally to ``Attention2D``.
        dtype, device: Forwarded to every layer created here.
        operations: Namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.attention = Attention2D(c, nhead, dropout, operations=operations, **factory)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, **factory),
        )

    def forward(self, x, kv):
        """Return ``x`` plus attention of ``norm(x)`` over ``kv_mapper(kv)``."""
        context = self.kv_mapper(kv)
        attended = self.attention(self.norm(x), context, self_attn=self.self_attn)
        return x + attended
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardBlock(nn.Module):
    """Residual channel-wise MLP applied after a layer norm.

    Args:
        c: Channel count of the input and output.
        dropout: Dropout probability inside the MLP.
        dtype, device: Forwarded to every layer created here.
        operations: Namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        hidden = c * 4
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.channelwise = nn.Sequential(
            operations.Linear(c, hidden, **factory),
            nn.GELU(),
            GlobalResponseNorm(hidden, **factory),
            nn.Dropout(dropout),
            operations.Linear(hidden, c, **factory),
        )

    def forward(self, x):
        """Return ``x + channelwise(norm(x))`` with channels moved last for the MLP."""
        normalized = self.norm(x).permute(0, 2, 3, 1)
        update = self.channelwise(normalized).permute(0, 3, 1, 2)
        return x + update
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepBlock(nn.Module):
    """Scale-and-shift modulation of a feature map from timestep-style embeddings.

    One ``Linear`` mapper is created for the main embedding (``mapper``) and
    one per entry of ``conds`` (``mapper_<name>``). Each maps ``c_timestep``
    features to ``2 * c`` values that are split into a scale and a shift.

    Args:
        c: Channel count of the feature map being modulated.
        c_timestep: Width of each individual embedding chunk.
        conds: Names of the additional conditioning embeddings. The value is
            copied, so later changes to the caller's sequence (or to another
            instance's ``conds``) do not affect this block.
        dtype, device: Forwarded to every ``Linear`` created here.
        operations: Namespace providing the ``Linear`` layer class.
    """

    def __init__(self, c, c_timestep, conds=('sca',), dtype=None, device=None, operations=None):
        super().__init__()
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        # Copy into a fresh list: the previous mutable default was stored by
        # reference and therefore shared between every instance using it.
        self.conds = list(conds)
        for cname in self.conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        """Return ``x * (1 + a) + b`` with ``a``/``b`` summed over all mappers.

        Args:
            x: Feature map; the scale and shift are broadcast over its last
                two dims.
            t: Embeddings concatenated along dim 1: the main embedding first,
                then one chunk per entry of ``conds`` in order.
        """
        t = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
|
271
comfy/ldm/cascade/stage_c.py
Normal file
271
comfy/ldm/cascade/stage_c.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
"""
|
||||||
|
This file is part of ComfyUI.
|
||||||
|
Copyright (C) 2024 Stability AI
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
|
||||||
|
# from .controlnet import ControlNetDeliverer
|
||||||
|
|
||||||
|
class UpDownBlock2d(nn.Module):
    """1x1 convolution combined with an optional 2x bilinear resize.

    For ``mode='up'`` the resize runs before the convolution; for
    ``mode='down'`` the convolution runs first.

    Args:
        c_in: Input channel count of the 1x1 convolution.
        c_out: Output channel count of the 1x1 convolution.
        mode: ``'up'`` (scale factor 2) or ``'down'`` (scale factor 0.5).
        enabled: When false, the resize is replaced by ``nn.Identity``.
        dtype, device: Forwarded to the convolution.
        operations: Namespace providing the ``Conv2d`` layer class.
    """

    def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert mode in ['up', 'down']
        upsampling = mode == 'up'
        if enabled:
            interpolation = nn.Upsample(scale_factor=2 if upsampling else 0.5, mode='bilinear',
                                        align_corners=True)
        else:
            interpolation = nn.Identity()
        mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
        stages = [interpolation, mapping] if upsampling else [mapping, interpolation]
        self.blocks = nn.ModuleList(stages)

    def forward(self, x):
        """Apply the two stages in their stored order."""
        out = x
        for stage in self.blocks:
            out = stage(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class StageC(nn.Module):
    """Stable Cascade stage C network: conditioned down/up stack of C/A/F/T blocks.

    Args:
        c_in, c_out: Input/output channel counts (before/after patching).
        c_r: Width of each sinusoidal embedding built by ``gen_r_embedding``.
        patch_size: Factor for the ``PixelUnshuffle``/``PixelShuffle`` pair.
        c_cond: Width of the conditioning tokens after mapping.
        c_hidden: Channel width per level.
        nhead: Attention head count per level.
        blocks: ``[down, up]`` lists with the block-group count per level.
        block_repeat: ``[down, up]`` lists with the repeat count per level;
            ``repeat - 1`` 1x1 convolutions are created per level.
        level_config: Per-level string of block type letters
            (``C`` ResBlock, ``A`` AttnBlock, ``F`` FeedForwardBlock,
            ``T`` TimestepBlock).
        c_clip_text, c_clip_text_pooled, c_clip_img: Input widths of the three
            conditioning mappers.
        c_clip_seq: Number of tokens each pooled/image embedding expands to.
        kernel_size: Depthwise kernel size used by ResBlocks.
        dropout: Per-level dropout, or one value used for every level.
        self_attn: Per-level flag, or one value used for every level.
        t_conds: Names of extra timestep-style conditions read from
            ``forward`` keyword arguments.
        switch_level: Per-transition flags enabling the 2x resize.
        stable_cascade_stage: Accepted but unused in this class.
        dtype, device: Forwarded to the layers created here.
        operations: Namespace providing ``Linear``, ``Conv2d``, ``LayerNorm``.

    NOTE: the list defaults are only read, never mutated, so they are kept
    as-is to leave the signature unchanged.
    """

    def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                 blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                 c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
                 dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
                 dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
        self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Map one letter of ``level_config`` to a freshly built block.
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    # Built through ``operations`` like every other layer so the
                    # mapper honours the requested dtype/device.
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the first block of every level except the deepest
                    # one receives a skip connection.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # NOTE: the weight-initialisation code that upstream keeps here as
        # comments is not executed; no explicit initialisation happens in
        # this constructor.

    @staticmethod
    def _is_block(block, block_cls):
        """Return True if ``block`` is a ``block_cls``, directly or via an FSDP wrapper."""
        if isinstance(block, block_cls):
            return True
        return hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, block_cls)

    @staticmethod
    def _add_cnet(x, cnet):
        """Add the next ControlNet projection (resized to ``x``) when one is available."""
        if cnet is None:
            return x
        next_cnet = cnet()
        if next_cnet is None:
            return x
        return x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                             align_corners=True)

    def gen_r_embedding(self, r, max_positions=10000):
        """Build a sinusoidal embedding of width ``c_r`` for the 1D tensor ``r``.

        ``r`` is scaled by ``max_positions``; sines fill the first half and
        cosines the second. An odd ``c_r`` gets one trailing zero column.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        """Map the three CLIP inputs to one normalised token sequence.

        2D pooled/image inputs get a token axis inserted at dim 1. Each of
        their tokens is expanded to ``c_clip_seq`` tokens, then text, pooled
        and image tokens are concatenated along dim 1 and layer-normalised.
        """
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip, cnet=None):
        """Run the down path and return the per-level outputs, deepest level first."""
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if self._is_block(block, ResBlock):
                        x = self._add_cnet(x, cnet)
                        x = block(x)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            # Prepending keeps the deepest level at index 0 for ``_up_decode``.
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        """Run the up path starting from ``level_outputs[0]`` and return the result."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if self._is_block(block, ResBlock):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        # Resize to the skip tensor when spatial sizes differ.
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        x = self._add_cnet(x, cnet)
                        x = block(x, skip)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
        """Run the network.

        Args:
            x: Input tensor fed to ``embedding``.
            r: 1D tensor embedded with ``gen_r_embedding``.
            clip_text, clip_text_pooled, clip_img: Inputs of
                ``gen_c_embeddings``.
            cnet: Optional sequence of ControlNet projections. They are
                consumed in order, one per ResBlock visit, and added to the
                activations.
            **kwargs: May contain one tensor per name in ``t_conds``; missing
                names default to zeros shaped like ``r``.

        Returns:
            The output of ``clf`` applied to the decoded features.
        """
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        # Model Blocks
        x = self.embedding(x)
        if cnet is not None:
            # The external ``ControlNetDeliverer`` import is disabled in this
            # file, so a local equivalent is used instead of an undefined name.
            cnet = _ControlNetDeliverer(cnet)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """Blend ``src_model`` parameters and buffers into this model: ``self * beta + src * (1 - beta)``."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)


class _ControlNetDeliverer:
    """Hand out ControlNet projections one per call, then ``None`` once exhausted.

    NOTE(review): written to match how ``cnet()`` is consumed by
    ``StageC._down_encode`` / ``StageC._up_decode``; confirm it agrees with
    the reference ``ControlNetDeliverer`` before relying on ControlNet input.
    """

    def __init__(self, controlnet_projections):
        self.controlnet_projections = controlnet_projections
        self.idx = 0

    def __call__(self):
        if self.idx < len(self.controlnet_projections):
            output = self.controlnet_projections[self.idx]
        else:
            output = None
        self.idx += 1
        return output
|
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
|
from comfy.ldm.cascade.stage_c import StageC
|
||||||
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -12,9 +13,10 @@ class ModelType(Enum):
|
|||||||
EPS = 1
|
EPS = 1
|
||||||
V_PREDICTION = 2
|
V_PREDICTION = 2
|
||||||
V_PREDICTION_EDM = 3
|
V_PREDICTION_EDM = 3
|
||||||
|
STABLE_CASCADE = 4
|
||||||
|
|
||||||
|
|
||||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
|
||||||
|
|
||||||
|
|
||||||
def model_sampling(model_config, model_type):
|
def model_sampling(model_config, model_type):
|
||||||
@ -27,6 +29,9 @@ def model_sampling(model_config, model_type):
|
|||||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||||
c = V_PREDICTION
|
c = V_PREDICTION
|
||||||
s = ModelSamplingContinuousEDM
|
s = ModelSamplingContinuousEDM
|
||||||
|
elif model_type == ModelType.STABLE_CASCADE:
|
||||||
|
c = EPS
|
||||||
|
s = StableCascadeSampling
|
||||||
|
|
||||||
class ModelSampling(s, c):
|
class ModelSampling(s, c):
|
||||||
pass
|
pass
|
||||||
@ -35,7 +40,7 @@ def model_sampling(model_config, model_type):
|
|||||||
|
|
||||||
|
|
||||||
class BaseModel(torch.nn.Module):
|
class BaseModel(torch.nn.Module):
|
||||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
unet_config = model_config.unet_config
|
unet_config = model_config.unet_config
|
||||||
@ -48,7 +53,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
operations = comfy.ops.manual_cast
|
operations = comfy.ops.manual_cast
|
||||||
else:
|
else:
|
||||||
operations = comfy.ops.disable_weight_init
|
operations = comfy.ops.disable_weight_init
|
||||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model_sampling = model_sampling(model_config, model_type)
|
self.model_sampling = model_sampling(model_config, model_type)
|
||||||
|
|
||||||
@ -427,3 +432,32 @@ class SD_X4Upscaler(BaseModel):
|
|||||||
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
|
||||||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class StableCascade_C(BaseModel):
    """Base-model wrapper that uses ``StageC`` as the diffusion model."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageC)
        # Put the network in eval mode with gradients disabled.
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        """Build the extra conditioning dict passed on to the diffusion model.

        Reads ``pooled_output``, ``unclip_conditioning`` and ``cross_attn``
        from ``kwargs`` when present.

        Returns:
            Dict with ``clip_img``, ``sca`` and ``crp`` always set, plus
            ``clip_text_pooled`` and ``clip_text`` when the corresponding
            inputs are available.
        """
        out = {}
        # ``.get`` matches the ``cross_attn`` lookup below: a missing key is
        # treated like an explicit None instead of raising KeyError.
        clip_text_pooled = kwargs.get("pooled_output", None)
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

        if "unclip_conditioning" in kwargs:
            embeds = []
            for unclip_cond in kwargs["unclip_conditioning"]:
                weight = unclip_cond["strength"]
                embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
            clip_img = torch.cat(embeds, dim=1)
        else:
            # No image conditioning: a single all-zero embedding is used.
            clip_img = torch.zeros((1, 1, 768))
        out["clip_img"] = comfy.conds.CONDRegular(clip_img)
        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
        out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out
|
||||||
|
|
||||||
|
@ -28,9 +28,26 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
|||||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
def detect_unet_config(state_dict, key_prefix):
|
||||||
state_dict_keys = list(state_dict.keys())
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
|
||||||
|
if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
|
||||||
|
unet_config = {}
|
||||||
|
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
|
||||||
|
if text_mapper_name in state_dict_keys:
|
||||||
|
unet_config['stable_cascade_stage'] = 'c'
|
||||||
|
w = state_dict[text_mapper_name]
|
||||||
|
if w.shape[0] == 1536: #stage c lite
|
||||||
|
unet_config['c_cond'] = 1536
|
||||||
|
unet_config['c_hidden'] = [1536, 1536]
|
||||||
|
unet_config['nhead'] = [24, 24]
|
||||||
|
unet_config['blocks'] = [[4, 12], [12, 4]]
|
||||||
|
elif w.shape[0] == 2048: #stage c full
|
||||||
|
unet_config['c_cond'] = 2048
|
||||||
|
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
unet_config['stable_cascade_stage'] = 'b'
|
||||||
|
return unet_config
|
||||||
|
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"use_checkpoint": False,
|
"use_checkpoint": False,
|
||||||
"image_size": 32,
|
"image_size": 32,
|
||||||
@ -45,7 +62,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
|||||||
else:
|
else:
|
||||||
unet_config["adm_in_channels"] = None
|
unet_config["adm_in_channels"] = None
|
||||||
|
|
||||||
unet_config["dtype"] = dtype
|
|
||||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||||
|
|
||||||
@ -159,8 +175,8 @@ def model_config_from_unet_config(unet_config):
|
|||||||
print("no match", unet_config)
|
print("no match", unet_config)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||||
model_config = model_config_from_unet_config(unet_config)
|
model_config = model_config_from_unet_config(unet_config)
|
||||||
if model_config is None and use_base_if_no_match:
|
if model_config is None and use_base_if_no_match:
|
||||||
return comfy.supported_models_base.BASE(unet_config)
|
return comfy.supported_models_base.BASE(unet_config)
|
||||||
@ -206,7 +222,7 @@ def convert_config(unet_config):
|
|||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
|
|
||||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||||
match = {}
|
match = {}
|
||||||
transformer_depth = []
|
transformer_depth = []
|
||||||
|
|
||||||
@ -313,8 +329,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
|||||||
return convert_config(unet_config)
|
return convert_config(unet_config)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_config_from_diffusers_unet(state_dict):
    """Resolve a model config for a UNet state dict stored in diffusers layout.

    The UNet config is first detected with ``unet_config_from_diffusers_unet``;
    when detection succeeds the result of ``model_config_from_unet_config`` for
    that config is returned, otherwise ``None``.
    """
    detected = unet_config_from_diffusers_unet(state_dict)
    if detected is None:
        return None
    return model_config_from_unet_config(detected)
|
||||||
|
@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
else:
|
else:
|
||||||
return cpu_dev
|
return cpu_dev
|
||||||
|
|
||||||
def unet_dtype(device=None, model_params=0):
|
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||||
if args.bf16_unet:
|
if args.bf16_unet:
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if args.fp16_unet:
|
if args.fp16_unet:
|
||||||
@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0):
|
|||||||
if args.fp8_e5m2_unet:
|
if args.fp8_e5m2_unet:
|
||||||
return torch.float8_e5m2
|
return torch.float8_e5m2
|
||||||
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
|
||||||
|
if torch.float16 in supported_dtypes:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
|
if should_use_bf16(device):
|
||||||
|
if torch.bfloat16 in supported_dtypes:
|
||||||
|
return torch.bfloat16
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
# None means no manual cast
|
# None means no manual cast
|
||||||
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=None):
    """Pick the dtype weights have to be manually cast to for inference.

    Args:
        weight_dtype: dtype the UNet weights are stored in.
        inference_device: device the model is going to run on.
        supported_dtypes: dtypes the model is able to run inference in.
            ``None`` (the default) means float16, bfloat16 and float32 are
            all acceptable.

    Returns:
        ``None`` when no manual cast is needed (the weights are float32, or
        they are float16/bfloat16 and the device check for that dtype
        passes); otherwise the dtype to cast to, preferring float16, then
        bfloat16, then float32.
    """
    # A None sentinel avoids a shared mutable list default and lets callers
    # pass None explicitly to mean "no restriction".
    if supported_dtypes is None:
        supported_dtypes = (torch.float16, torch.bfloat16, torch.float32)

    if weight_dtype == torch.float32:
        return None

    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
    if fp16_supported and weight_dtype == torch.float16:
        return None

    bf16_supported = should_use_bf16(inference_device)
    if bf16_supported and weight_dtype == torch.bfloat16:
        return None

    # The stored dtype cannot be used directly: pick the first usable one.
    if fp16_supported and torch.float16 in supported_dtypes:
        return torch.float16
    if bf16_supported and torch.bfloat16 in supported_dtypes:
        return torch.bfloat16
    return torch.float32
|
||||||
|
|
||||||
@ -760,6 +771,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def should_use_bf16(device=None):
    """Report whether bfloat16 should be used for inference on ``device``.

    Args:
        device: a ``torch.device`` (or a string / index accepted by
            ``torch.device``). ``None`` means the default CUDA device.

    Returns:
        ``False`` for CPU devices, ``True`` when ``is_intel_xpu()`` is true,
        ``False`` for any other non-CUDA device or when CUDA is unavailable,
        and otherwise whether the CUDA device reports a major compute
        capability of at least 8.
    """
    if isinstance(device, (str, int)):
        device = torch.device(device)

    # CPU devices have no CUDA properties to query; answer before touching
    # any accelerator API.
    if device is not None and getattr(device, "type", None) == "cpu":
        return False

    if is_intel_xpu():
        return True

    if device is None:
        device = torch.device("cuda")

    # torch.cuda.get_device_properties() raises for non-CUDA devices and when
    # no CUDA runtime is present, so fall back to "not supported" instead.
    if getattr(device, "type", "cuda") != "cuda" or not torch.cuda.is_available():
        return False

    props = torch.cuda.get_device_properties(device)
    return props.major >= 8
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
|
@ -132,3 +132,33 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
|||||||
|
|
||||||
log_sigma_min = math.log(self.sigma_min)
|
log_sigma_min = math.log(self.sigma_min)
|
||||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||||
|
|
||||||
|
class StableCascadeSampling(ModelSamplingDiscrete):
    """Discrete sampling whose sigmas come from a cosine alpha-cumprod curve."""

    def __init__(self, model_config=None):
        # NOTE(review): model_config is accepted but not forwarded to the
        # parent constructor.
        super().__init__()
        self.num_timesteps = 1000

        offset = 8e-3
        self.cosine_s = torch.tensor([offset])
        # Value of the cosine curve at t = 0, used as the normaliser in sigma().
        self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2

        schedule = torch.empty(self.num_timesteps, dtype=torch.float32)
        for step in range(self.num_timesteps):
            schedule[step] = self.sigma(step / self.num_timesteps)

        self.set_sigmas(schedule)

    def sigma(self, timestep):
        """Return sqrt((1 - a) / a) for the clamped cosine alpha-cumprod ``a``."""
        angle = (timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5
        alpha_cumprod = (torch.cos(angle) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999)
        return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5

    def timestep(self, sigma):
        """Return the parent class' timestep for ``sigma`` divided by 1000."""
        discrete = super().timestep(sigma)
        return discrete / 1000.0

    def percent_to_sigma(self, percent):
        """Map a progress fraction to a sigma.

        ``percent <= 0`` yields 999999999.9 and ``percent >= 1`` yields 0.0;
        anything in between is the sigma at ``1 - percent`` (as a tensor).
        """
        if percent <= 0.0:
            return 999999999.9
        if percent >= 1.0:
            return 0.0

        remaining = 1.0 - percent
        return self.sigma(torch.tensor(remaining))
|
||||||
|
20
comfy/sd.py
20
comfy/sd.py
@ -450,15 +450,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
clip_target = None
|
clip_target = None
|
||||||
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
class WeightsLoader(torch.nn.Module):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
|
||||||
model_config.set_manual_cast(manual_cast_dtype)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||||
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||||
@ -507,16 +507,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
|||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
|
||||||
|
|
||||||
if "input_blocks.0.0.weight" in sd: #ldm
|
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
|
||||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
model_config = model_detection.model_config_from_unet(sd, "")
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
return None
|
return None
|
||||||
new_sd = sd
|
new_sd = sd
|
||||||
|
|
||||||
else: #diffusers
|
else: #diffusers
|
||||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
model_config = model_detection.model_config_from_diffusers_unet(sd)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -528,8 +527,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
|||||||
new_sd[diffusers_keys[k]] = sd.pop(k)
|
new_sd[diffusers_keys[k]] = sd.pop(k)
|
||||||
else:
|
else:
|
||||||
print(diffusers_keys[k], k)
|
print(diffusers_keys[k], k)
|
||||||
|
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
model_config.set_manual_cast(manual_cast_dtype)
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
|
||||||
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
|
@ -306,5 +306,38 @@ class SD_X4Upscaler(SD20):
|
|||||||
out = model_base.SD_X4Upscaler(self, device=device)
|
out = model_base.SD_X4Upscaler(self, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
class Stable_Cascade_C(supported_models_base.BASE):
    """Model config for the Stable Cascade stage C UNet."""

    # Key/value pairs a detected UNet config must carry to select this class.
    unet_config = {
        "stable_cascade_stage": 'c',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_Prior
    # float16 is not listed among the inference dtypes for this model.
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    def process_unet_state_dict(self, state_dict):
        """Split fused ``in_proj_weight`` / ``in_proj_bias`` tensors in place.

        Every key ending in ``in_proj_<weight|bias>`` is removed and replaced
        by three equally sized slices along dim 0, stored under
        ``<prefix>.to_q.<y>``, ``<prefix>.to_k.<y>`` and ``<prefix>.to_v.<y>``.
        The mutated ``state_dict`` is returned.
        """
        targets = ("to_q", "to_k", "to_v")
        # Snapshot the keys first: the dict is modified while we walk them.
        original_keys = list(state_dict.keys())
        for kind in ("weight", "bias"):
            suffix = "in_proj_{}".format(kind)
            fused_keys = [key for key in original_keys if key.endswith(suffix)]
            for fused_key in fused_keys:
                fused = state_dict.pop(fused_key)
                # Drop the suffix together with the dot that precedes it.
                prefix = fused_key[:-(len(suffix) + 1)]
                chunk = fused.shape[0] // 3
                for index, target in enumerate(targets):
                    new_key = "{}.{}.{}".format(prefix, target, kind)
                    state_dict[new_key] = fused[chunk * index:chunk * (index + 1)]
        return state_dict

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.StableCascade_C`` wrapper for this config."""
        return model_base.StableCascade_C(self, device=device)

    def clip_target(self):
        """This config provides no CLIP target; always returns ``None``."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# Model config classes exposed by this module.
# NOTE(review): list order presumably decides which config wins when several
# match - confirm against the code that iterates over `models`.
models = [
    Stable_Zero123,
    SD15,
    SD20,
    SD21UnclipL,
    SD21UnclipH,
    SDXLRefiner,
    SDXL,
    SSD1B,
    Segmind_Vega,
    SD_X4Upscaler,
    Stable_Cascade_C,
]
models += [SVD_img2vid]
|
||||||
|
@ -22,13 +22,14 @@ class BASE:
|
|||||||
sampling_settings = {}
|
sampling_settings = {}
|
||||||
latent_format = latent_formats.LatentFormat
|
latent_format = latent_formats.LatentFormat
|
||||||
vae_key_prefix = ["first_stage_model."]
|
vae_key_prefix = ["first_stage_model."]
|
||||||
|
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
|
|
||||||
@classmethod
def matches(s, unet_config):
    """Tell whether ``unet_config`` satisfies this class' ``unet_config``.

    Every key of the class-level ``unet_config`` must be present in the
    given config with an equal value; extra keys in the given config are
    ignored.
    """
    required = s.unet_config
    for key in required:
        # A key the detected config lacks can never match.
        if key not in unet_config:
            return False
        if required[key] != unet_config[key]:
            return False
    return True
|
||||||
|
|
||||||
@ -80,5 +81,6 @@ class BASE:
|
|||||||
replace_prefix = {"": "first_stage_model."}
|
replace_prefix = {"": "first_stage_model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
    """Record the dtype choices made for this model.

    ``dtype`` is stored under the ``'dtype'`` key of ``self.unet_config`` and
    ``manual_cast_dtype`` is kept on the instance as ``manual_cast_dtype``.
    """
    config = self.unet_config
    config['dtype'] = dtype
    self.manual_cast_dtype = manual_cast_dtype
|
||||||
|
Loading…
Reference in New Issue
Block a user