mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-10 18:05:16 +00:00
Add ControlNet support.
This commit is contained in:
parent
bc69fb5245
commit
4efa67fa12
286
comfy/cldm/cldm.py
Normal file
286
comfy/cldm/cldm.py
Normal file
@ -0,0 +1,286 @@
|
||||
#taken from: https://github.com/lllyasviel/ControlNet
|
||||
#and modified
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
conv_nd,
|
||||
linear,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
#implemented in the ldm unet
|
||||
pass
|
||||
|
||||
class ControlNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.dims = dims
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
conv_nd(dims, hint_channels, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
||||
)
|
||||
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self.middle_block_out = self.make_zero_conv(ch)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels):
|
||||
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
|
||||
outs = []
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
h += guided_hint
|
||||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
outs.append(zero_conv(h, emb, context))
|
||||
|
||||
h = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
return outs
|
||||
|
@ -856,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
|
||||
|
||||
device = noise.device
|
||||
|
||||
if model.inner_model.parameterization == "v":
|
||||
if model.parameterization == "v":
|
||||
model_type = "v"
|
||||
else:
|
||||
model_type = "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
model.inner_model.apply_model,
|
||||
model.inner_model.inner_model.apply_model,
|
||||
sampling_function,
|
||||
ns,
|
||||
model_type=model_type,
|
||||
|
@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
self.conditioning_key = conditioning_key
|
||||
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
||||
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
|
||||
if self.conditioning_key is None:
|
||||
out = self.diffusion_model(x, t)
|
||||
out = self.diffusion_model(x, t, control=control)
|
||||
elif self.conditioning_key == 'concat':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
out = self.diffusion_model(xc, t)
|
||||
out = self.diffusion_model(xc, t, control=control)
|
||||
elif self.conditioning_key == 'crossattn':
|
||||
if not self.sequential_cross_attn:
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
# TorchScript changes names of the arguments
|
||||
# with argument cc defined as context=cc scripted model will produce
|
||||
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||
out = self.scripted_diffusion_model(x, t, cc)
|
||||
out = self.scripted_diffusion_model(x, t, cc, control=control)
|
||||
else:
|
||||
out = self.diffusion_model(x, t, context=cc)
|
||||
out = self.diffusion_model(x, t, context=cc, control=control)
|
||||
elif self.conditioning_key == 'hybrid':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc)
|
||||
out = self.diffusion_model(xc, t, context=cc, control=control)
|
||||
elif self.conditioning_key == 'hybrid-adm':
|
||||
assert c_adm is not None
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
|
||||
elif self.conditioning_key == 'crossattn-adm':
|
||||
assert c_adm is not None
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
|
||||
elif self.conditioning_key == 'adm':
|
||||
cc = c_crossattn[0]
|
||||
out = self.diffusion_model(x, t, y=cc)
|
||||
out = self.diffusion_model(x, t, y=cc, control=control)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -753,7 +753,7 @@ class UNetModel(nn.Module):
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@ -778,8 +778,14 @@ class UNetModel(nn.Module):
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
if control is not None:
|
||||
h += control.pop()
|
||||
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
hsp = hs.pop()
|
||||
if control is not None:
|
||||
hsp += control.pop()
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
|
@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
|
||||
current_gpu_controlnets = []
|
||||
|
||||
model_accelerated = False
|
||||
|
||||
@ -56,6 +56,7 @@ model_accelerated = False
|
||||
def unload_model():
|
||||
global current_loaded_model
|
||||
global model_accelerated
|
||||
global current_gpu_controlnets
|
||||
if current_loaded_model is not None:
|
||||
if model_accelerated:
|
||||
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||
@ -64,6 +65,10 @@ def unload_model():
|
||||
current_loaded_model.model.cpu()
|
||||
current_loaded_model.unpatch_model()
|
||||
current_loaded_model = None
|
||||
if len(current_gpu_controlnets) > 0:
|
||||
for n in current_gpu_controlnets:
|
||||
n.cpu()
|
||||
current_gpu_controlnets = []
|
||||
|
||||
|
||||
def load_model_gpu(model):
|
||||
@ -95,6 +100,16 @@ def load_model_gpu(model):
|
||||
model_accelerated = True
|
||||
return current_loaded_model
|
||||
|
||||
def load_controlnet_gpu(models):
|
||||
global current_gpu_controlnets
|
||||
for m in current_gpu_controlnets:
|
||||
if m not in models:
|
||||
m.cpu()
|
||||
|
||||
current_gpu_controlnets = []
|
||||
for m in models:
|
||||
current_gpu_controlnets.append(m.cuda())
|
||||
|
||||
|
||||
def get_free_memory():
|
||||
dev = torch.cuda.current_device()
|
||||
|
@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
|
||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
|
||||
def get_area_and_mult(cond, x_in, cond_concat_in):
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
|
||||
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
min_sigma = 0.0
|
||||
max_sigma = 999.0
|
||||
if 'area' in cond[1]:
|
||||
area = cond[1]['area']
|
||||
if 'strength' in cond[1]:
|
||||
@ -56,9 +57,15 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
cropped.append(cr)
|
||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||
return (input_x, mult, conditionning, area)
|
||||
|
||||
control = None
|
||||
if 'control' in cond[1]:
|
||||
control = cond[1]['control']
|
||||
return (input_x, mult, conditionning, area, control)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
return True
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
@ -69,6 +76,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
return False
|
||||
return True
|
||||
|
||||
def can_concat_cond(c1, c2):
|
||||
if c1[0].shape != c2[0].shape:
|
||||
return False
|
||||
if (c1[4] is None) != (c2[4] is None):
|
||||
return False
|
||||
if c1[4] is not None:
|
||||
if c1[4] is not c2[4]:
|
||||
return False
|
||||
|
||||
return cond_equal_size(c1[2], c2[2])
|
||||
|
||||
def cond_cat(c_list):
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
@ -84,7 +102,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
out['c_concat'] = [torch.cat(c_concat)]
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in)/100000.0
|
||||
|
||||
@ -96,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in)
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in)
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
@ -113,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if to_run[x][0][0].shape == first_shape:
|
||||
if cond_equal_size(to_run[x][0][2], first[0][2]):
|
||||
to_batch_temp += [x]
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
@ -131,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
@ -139,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
c += [p[2]]
|
||||
area += [p[3]]
|
||||
cond_or_uncond += [o[1]]
|
||||
control = p[4]
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
sigma_ = torch.cat([sigma] * batch_chunks)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
|
||||
|
||||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
@ -166,10 +188,29 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
||||
|
||||
|
||||
max_total_area = model_management.maximum_batch_area()
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
class CFGDenoiserComplex(torch.nn.Module):
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
||||
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
|
||||
return out
|
||||
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
@ -177,7 +218,7 @@ class CFGDenoiserComplex(torch.nn.Module):
|
||||
if denoise_mask is not None:
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
||||
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
|
||||
if denoise_mask is not None:
|
||||
out *= denoise_mask
|
||||
|
||||
@ -196,8 +237,6 @@ def simple_scheduler(model, steps):
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
|
||||
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
@ -234,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
|
||||
n = c[1].copy()
|
||||
conds += [[smallest[0], n]]
|
||||
|
||||
|
||||
def apply_control_net_to_equal_area(conds, uncond):
|
||||
cond_cnets = []
|
||||
cond_other = []
|
||||
uncond_cnets = []
|
||||
uncond_other = []
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
if 'area' not in x[1]:
|
||||
if 'control' in x[1] and x[1]['control'] is not None:
|
||||
cond_cnets.append(x[1]['control'])
|
||||
else:
|
||||
cond_other.append((x, t))
|
||||
for t in range(len(uncond)):
|
||||
x = uncond[t]
|
||||
if 'area' not in x[1]:
|
||||
if 'control' in x[1] and x[1]['control'] is not None:
|
||||
uncond_cnets.append(x[1]['control'])
|
||||
else:
|
||||
uncond_other.append((x, t))
|
||||
|
||||
if len(uncond_cnets) > 0:
|
||||
return
|
||||
|
||||
for x in range(len(cond_cnets)):
|
||||
temp = uncond_other[x % len(uncond_other)]
|
||||
o = temp[0]
|
||||
if 'control' in o[1] and o[1]['control'] is not None:
|
||||
n = o[1].copy()
|
||||
n['control'] = cond_cnets[x]
|
||||
uncond += [[o[0], n]]
|
||||
else:
|
||||
n = o[1].copy()
|
||||
n['control'] = cond_cnets[x]
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple"]
|
||||
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
||||
@ -242,11 +317,13 @@ class KSampler:
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
||||
self.model = model
|
||||
self.model_denoise = CFGNoisePredictor(self.model)
|
||||
if self.model.parameterization == "v":
|
||||
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
|
||||
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
|
||||
else:
|
||||
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
|
||||
self.model_k = CFGDenoiserComplex(self.model_wrap)
|
||||
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
|
||||
self.model_wrap.parameterization = self.model.parameterization
|
||||
self.model_k = KSamplerX0Inpaint(self.model_wrap)
|
||||
self.device = device
|
||||
if scheduler not in self.SCHEDULERS:
|
||||
scheduler = self.SCHEDULERS[0]
|
||||
@ -316,6 +393,8 @@ class KSampler:
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
apply_control_net_to_equal_area(positive, negative)
|
||||
|
||||
if self.model.model.diffusion_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
|
76
comfy/sd.py
76
comfy/sd.py
@ -6,6 +6,9 @@ import model_management
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.autoencoder import AutoencoderKL
|
||||
from omegaconf import OmegaConf
|
||||
from .cldm import cldm
|
||||
|
||||
from . import utils
|
||||
|
||||
def load_torch_file(ckpt):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
@ -323,6 +326,79 @@ class VAE:
|
||||
samples = samples.cpu()
|
||||
return samples
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt):
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
|
||||
print("set cond_hint", self.cond_hint.shape)
|
||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||
return control
|
||||
|
||||
def set_cond_hint(self, cond_hint):
|
||||
self.cond_hint_original = cond_hint
|
||||
return self
|
||||
|
||||
def cleanup(self):
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model)
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
return c
|
||||
|
||||
def load_controlnet(ckpt_path):
|
||||
controlnet_data = load_torch_file(ckpt_path)
|
||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
pth = False
|
||||
sd2 = False
|
||||
key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
if pth_key in controlnet_data:
|
||||
pth = True
|
||||
key = pth_key
|
||||
elif key in controlnet_data:
|
||||
pass
|
||||
else:
|
||||
print("error checkpoint does not contain controlnet data", ckpt_path)
|
||||
return None
|
||||
|
||||
context_dim = controlnet_data[key].shape[1]
|
||||
control_model = cldm.ControlNet(image_size=32,
|
||||
in_channels=4,
|
||||
hint_channels=3,
|
||||
model_channels=320,
|
||||
attention_resolutions=[ 4, 2, 1 ],
|
||||
num_res_blocks=2,
|
||||
channel_mult=[ 1, 2, 4, 4 ],
|
||||
num_heads=8,
|
||||
use_spatial_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
legacy=False)
|
||||
|
||||
if pth:
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.control_model = control_model
|
||||
w.load_state_dict(controlnet_data, strict=False)
|
||||
else:
|
||||
control_model.load_state_dict(controlnet_data, strict=False)
|
||||
|
||||
control = ControlNet(control_model)
|
||||
return control
|
||||
|
||||
|
||||
def load_clip(ckpt_path, embedding_directory=None):
|
||||
clip_data = load_torch_file(ckpt_path)
|
||||
config = {}
|
||||
|
18
comfy/utils.py
Normal file
18
comfy/utils.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
old_height = samples.shape[2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
y = 0
|
||||
if old_aspect > new_aspect:
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||
else:
|
||||
s = samples
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
93
nodes.py
93
nodes.py
@ -15,10 +15,12 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
|
||||
|
||||
import comfy.samplers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
|
||||
import model_management
|
||||
|
||||
supported_ckpt_extensions = ['.ckpt']
|
||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin']
|
||||
supported_ckpt_extensions = ['.ckpt', '.pth']
|
||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
|
||||
try:
|
||||
import safetensors.torch
|
||||
supported_ckpt_extensions += ['.safetensors']
|
||||
@ -77,12 +79,14 @@ class ConditioningSetArea:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
|
||||
c = copy.deepcopy(conditioning)
|
||||
for t in c:
|
||||
t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||
t[1]['strength'] = strength
|
||||
t[1]['min_sigma'] = min_sigma
|
||||
t[1]['max_sigma'] = max_sigma
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||
n[1]['strength'] = strength
|
||||
n[1]['min_sigma'] = min_sigma
|
||||
n[1]['max_sigma'] = max_sigma
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class VAEDecode:
|
||||
@ -134,7 +138,6 @@ class VAEEncodeForInpaint:
|
||||
CATEGORY = "latent/inpaint"
|
||||
|
||||
def encode(self, vae, pixels, mask):
|
||||
print(pixels.shape, mask.shape)
|
||||
x = (pixels.shape[1] // 64) * 64
|
||||
y = (pixels.shape[2] // 64) * 64
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
@ -144,7 +147,6 @@ class VAEEncodeForInpaint:
|
||||
#shave off a few pixels to keep things seamless
|
||||
kernel_tensor = torch.ones((1, 1, 6, 6))
|
||||
mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
|
||||
print(mask_erosion.shape, pixels.shape)
|
||||
for i in range(3):
|
||||
pixels[:,:,:,i] -= 0.5
|
||||
pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
|
||||
@ -211,6 +213,44 @@ class VAELoader:
|
||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_controlnet"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, control_net_name):
|
||||
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||
controlnet = comfy.sd.load_controlnet(controlnet_path)
|
||||
return (controlnet,)
|
||||
|
||||
|
||||
class ControlNetApply:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image):
|
||||
c = []
|
||||
control_hint = image.movedim(-1,1)
|
||||
print(control_hint.shape)
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['control'] = control_net.copy().set_cond_hint(control_hint)
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
|
||||
class CLIPLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
clip_dir = os.path.join(models_dir, "clip")
|
||||
@ -248,22 +288,7 @@ class EmptyLatentImage:
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||
return ({"samples":latent}, )
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
old_height = samples.shape[2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
y = 0
|
||||
if old_aspect > new_aspect:
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||
else:
|
||||
s = samples
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
|
||||
class LatentUpscale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
@ -282,7 +307,7 @@ class LatentUpscale:
|
||||
|
||||
def upscale(self, samples, upscale_method, width, height, crop):
|
||||
s = samples.copy()
|
||||
s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||
return (s,)
|
||||
|
||||
class LatentRotate:
|
||||
@ -461,19 +486,26 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
||||
positive_copy = []
|
||||
negative_copy = []
|
||||
|
||||
control_nets = []
|
||||
for p in positive:
|
||||
t = p[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in p[1]:
|
||||
control_nets += [p[1]['control']]
|
||||
positive_copy += [[t] + p[1:]]
|
||||
for n in negative:
|
||||
t = n[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in p[1]:
|
||||
control_nets += [p[1]['control']]
|
||||
negative_copy += [[t] + n[1:]]
|
||||
|
||||
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
|
||||
|
||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
||||
else:
|
||||
@ -482,6 +514,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
||||
samples = samples.cpu()
|
||||
for c in control_nets:
|
||||
c.cleanup()
|
||||
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
return (out, )
|
||||
@ -676,7 +711,7 @@ class ImageScale:
|
||||
|
||||
def upscale(self, image, upscale_method, width, height, crop):
|
||||
samples = image.movedim(-1,1)
|
||||
s = common_upscale(samples, width, height, upscale_method, crop)
|
||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
|
||||
s = s.movedim(1,-1)
|
||||
return (s,)
|
||||
|
||||
@ -704,6 +739,8 @@ NODE_CLASS_MAPPINGS = {
|
||||
"LatentCrop": LatentCrop,
|
||||
"LoraLoader": LoraLoader,
|
||||
"CLIPLoader": CLIPLoader,
|
||||
"ControlNetApply": ControlNetApply,
|
||||
"ControlNetLoader": ControlNetLoader,
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user