mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Add ControlNet support.
This commit is contained in:
parent
bc69fb5245
commit
4efa67fa12
286
comfy/cldm/cldm.py
Normal file
286
comfy/cldm/cldm.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
#taken from: https://github.com/lllyasviel/ControlNet
|
||||||
|
#and modified
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torch as th
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import (
|
||||||
|
conv_nd,
|
||||||
|
linear,
|
||||||
|
zero_module,
|
||||||
|
timestep_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
from ldm.modules.attention import SpatialTransformer
|
||||||
|
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||||
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
|
class ControlledUnetModel(UNetModel):
|
||||||
|
#implemented in the ldm unet
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ControlNet(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_size,
|
||||||
|
in_channels,
|
||||||
|
model_channels,
|
||||||
|
hint_channels,
|
||||||
|
num_res_blocks,
|
||||||
|
attention_resolutions,
|
||||||
|
dropout=0,
|
||||||
|
channel_mult=(1, 2, 4, 8),
|
||||||
|
conv_resample=True,
|
||||||
|
dims=2,
|
||||||
|
use_checkpoint=False,
|
||||||
|
use_fp16=False,
|
||||||
|
num_heads=-1,
|
||||||
|
num_head_channels=-1,
|
||||||
|
num_heads_upsample=-1,
|
||||||
|
use_scale_shift_norm=False,
|
||||||
|
resblock_updown=False,
|
||||||
|
use_new_attention_order=False,
|
||||||
|
use_spatial_transformer=False, # custom transformer support
|
||||||
|
transformer_depth=1, # custom transformer support
|
||||||
|
context_dim=None, # custom transformer support
|
||||||
|
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||||
|
legacy=True,
|
||||||
|
disable_self_attentions=None,
|
||||||
|
num_attention_blocks=None,
|
||||||
|
disable_middle_self_attn=False,
|
||||||
|
use_linear_in_transformer=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if use_spatial_transformer:
|
||||||
|
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||||
|
|
||||||
|
if context_dim is not None:
|
||||||
|
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||||
|
from omegaconf.listconfig import ListConfig
|
||||||
|
if type(context_dim) == ListConfig:
|
||||||
|
context_dim = list(context_dim)
|
||||||
|
|
||||||
|
if num_heads_upsample == -1:
|
||||||
|
num_heads_upsample = num_heads
|
||||||
|
|
||||||
|
if num_heads == -1:
|
||||||
|
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||||
|
|
||||||
|
if num_head_channels == -1:
|
||||||
|
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||||
|
|
||||||
|
self.dims = dims
|
||||||
|
self.image_size = image_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.model_channels = model_channels
|
||||||
|
if isinstance(num_res_blocks, int):
|
||||||
|
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||||
|
else:
|
||||||
|
if len(num_res_blocks) != len(channel_mult):
|
||||||
|
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||||
|
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
if disable_self_attentions is not None:
|
||||||
|
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||||
|
assert len(disable_self_attentions) == len(channel_mult)
|
||||||
|
if num_attention_blocks is not None:
|
||||||
|
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||||
|
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||||
|
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||||
|
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||||
|
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||||
|
f"attention will still not be set.")
|
||||||
|
|
||||||
|
self.attention_resolutions = attention_resolutions
|
||||||
|
self.dropout = dropout
|
||||||
|
self.channel_mult = channel_mult
|
||||||
|
self.conv_resample = conv_resample
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.dtype = th.float16 if use_fp16 else th.float32
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_head_channels = num_head_channels
|
||||||
|
self.num_heads_upsample = num_heads_upsample
|
||||||
|
self.predict_codebook_ids = n_embed is not None
|
||||||
|
|
||||||
|
time_embed_dim = model_channels * 4
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
linear(model_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, time_embed_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
||||||
|
|
||||||
|
self.input_hint_block = TimestepEmbedSequential(
|
||||||
|
conv_nd(dims, hint_channels, 16, 3, padding=1),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(dims, 16, 16, 3, padding=1),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(dims, 32, 32, 3, padding=1),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(dims, 96, 96, 3, padding=1),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||||
|
nn.SiLU(),
|
||||||
|
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
||||||
|
)
|
||||||
|
|
||||||
|
self._feature_size = model_channels
|
||||||
|
input_block_chans = [model_channels]
|
||||||
|
ch = model_channels
|
||||||
|
ds = 1
|
||||||
|
for level, mult in enumerate(channel_mult):
|
||||||
|
for nr in range(self.num_res_blocks[level]):
|
||||||
|
layers = [
|
||||||
|
ResBlock(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
dropout,
|
||||||
|
out_channels=mult * model_channels,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
ch = mult * model_channels
|
||||||
|
if ds in attention_resolutions:
|
||||||
|
if num_head_channels == -1:
|
||||||
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
#num_heads = 1
|
||||||
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
|
if exists(disable_self_attentions):
|
||||||
|
disabled_sa = disable_self_attentions[level]
|
||||||
|
else:
|
||||||
|
disabled_sa = False
|
||||||
|
|
||||||
|
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||||
|
layers.append(
|
||||||
|
AttentionBlock(
|
||||||
|
ch,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_head_channels=dim_head,
|
||||||
|
use_new_attention_order=use_new_attention_order,
|
||||||
|
) if not use_spatial_transformer else SpatialTransformer(
|
||||||
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||||
|
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||||
|
use_checkpoint=use_checkpoint
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
self.zero_convs.append(self.make_zero_conv(ch))
|
||||||
|
self._feature_size += ch
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
if level != len(channel_mult) - 1:
|
||||||
|
out_ch = ch
|
||||||
|
self.input_blocks.append(
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
ResBlock(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
dropout,
|
||||||
|
out_channels=out_ch,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
down=True,
|
||||||
|
)
|
||||||
|
if resblock_updown
|
||||||
|
else Downsample(
|
||||||
|
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ch = out_ch
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
self.zero_convs.append(self.make_zero_conv(ch))
|
||||||
|
ds *= 2
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
if num_head_channels == -1:
|
||||||
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
#num_heads = 1
|
||||||
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
|
self.middle_block = TimestepEmbedSequential(
|
||||||
|
ResBlock(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
dropout,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
),
|
||||||
|
AttentionBlock(
|
||||||
|
ch,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_head_channels=dim_head,
|
||||||
|
use_new_attention_order=use_new_attention_order,
|
||||||
|
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||||
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||||
|
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||||
|
use_checkpoint=use_checkpoint
|
||||||
|
),
|
||||||
|
ResBlock(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
dropout,
|
||||||
|
dims=dims,
|
||||||
|
use_checkpoint=use_checkpoint,
|
||||||
|
use_scale_shift_norm=use_scale_shift_norm,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.middle_block_out = self.make_zero_conv(ch)
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
def make_zero_conv(self, channels):
|
||||||
|
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||||
|
|
||||||
|
def forward(self, x, hint, timesteps, context, **kwargs):
|
||||||
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||||
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
guided_hint = self.input_hint_block(hint, emb, context)
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||||
|
if guided_hint is not None:
|
||||||
|
h = module(h, emb, context)
|
||||||
|
h += guided_hint
|
||||||
|
guided_hint = None
|
||||||
|
else:
|
||||||
|
h = module(h, emb, context)
|
||||||
|
outs.append(zero_conv(h, emb, context))
|
||||||
|
|
||||||
|
h = self.middle_block(h, emb, context)
|
||||||
|
outs.append(self.middle_block_out(h, emb, context))
|
||||||
|
|
||||||
|
return outs
|
||||||
|
|
@ -856,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
|
|||||||
|
|
||||||
device = noise.device
|
device = noise.device
|
||||||
|
|
||||||
if model.inner_model.parameterization == "v":
|
if model.parameterization == "v":
|
||||||
model_type = "v"
|
model_type = "v"
|
||||||
else:
|
else:
|
||||||
model_type = "noise"
|
model_type = "noise"
|
||||||
|
|
||||||
model_fn = model_wrapper(
|
model_fn = model_wrapper(
|
||||||
model.inner_model.apply_model,
|
model.inner_model.inner_model.apply_model,
|
||||||
sampling_function,
|
sampling_function,
|
||||||
ns,
|
ns,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
|
@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
|
|||||||
self.conditioning_key = conditioning_key
|
self.conditioning_key = conditioning_key
|
||||||
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
||||||
|
|
||||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
|
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
|
||||||
if self.conditioning_key is None:
|
if self.conditioning_key is None:
|
||||||
out = self.diffusion_model(x, t)
|
out = self.diffusion_model(x, t, control=control)
|
||||||
elif self.conditioning_key == 'concat':
|
elif self.conditioning_key == 'concat':
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
out = self.diffusion_model(xc, t)
|
out = self.diffusion_model(xc, t, control=control)
|
||||||
elif self.conditioning_key == 'crossattn':
|
elif self.conditioning_key == 'crossattn':
|
||||||
if not self.sequential_cross_attn:
|
if not self.sequential_cross_attn:
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
|
|||||||
# TorchScript changes names of the arguments
|
# TorchScript changes names of the arguments
|
||||||
# with argument cc defined as context=cc scripted model will produce
|
# with argument cc defined as context=cc scripted model will produce
|
||||||
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||||
out = self.scripted_diffusion_model(x, t, cc)
|
out = self.scripted_diffusion_model(x, t, cc, control=control)
|
||||||
else:
|
else:
|
||||||
out = self.diffusion_model(x, t, context=cc)
|
out = self.diffusion_model(x, t, context=cc, control=control)
|
||||||
elif self.conditioning_key == 'hybrid':
|
elif self.conditioning_key == 'hybrid':
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(xc, t, context=cc)
|
out = self.diffusion_model(xc, t, context=cc, control=control)
|
||||||
elif self.conditioning_key == 'hybrid-adm':
|
elif self.conditioning_key == 'hybrid-adm':
|
||||||
assert c_adm is not None
|
assert c_adm is not None
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
|
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
|
||||||
elif self.conditioning_key == 'crossattn-adm':
|
elif self.conditioning_key == 'crossattn-adm':
|
||||||
assert c_adm is not None
|
assert c_adm is not None
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(x, t, context=cc, y=c_adm)
|
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
|
||||||
elif self.conditioning_key == 'adm':
|
elif self.conditioning_key == 'adm':
|
||||||
cc = c_crossattn[0]
|
cc = c_crossattn[0]
|
||||||
out = self.diffusion_model(x, t, y=cc)
|
out = self.diffusion_model(x, t, y=cc, control=control)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -753,7 +753,7 @@ class UNetModel(nn.Module):
|
|||||||
self.middle_block.apply(convert_module_to_f32)
|
self.middle_block.apply(convert_module_to_f32)
|
||||||
self.output_blocks.apply(convert_module_to_f32)
|
self.output_blocks.apply(convert_module_to_f32)
|
||||||
|
|
||||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
:param x: an [N x C x ...] Tensor of inputs.
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
@ -778,8 +778,14 @@ class UNetModel(nn.Module):
|
|||||||
h = module(h, emb, context)
|
h = module(h, emb, context)
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
h = self.middle_block(h, emb, context)
|
h = self.middle_block(h, emb, context)
|
||||||
|
if control is not None:
|
||||||
|
h += control.pop()
|
||||||
|
|
||||||
for module in self.output_blocks:
|
for module in self.output_blocks:
|
||||||
h = th.cat([h, hs.pop()], dim=1)
|
hsp = hs.pop()
|
||||||
|
if control is not None:
|
||||||
|
hsp += control.pop()
|
||||||
|
h = th.cat([h, hsp], dim=1)
|
||||||
h = module(h, emb, context)
|
h = module(h, emb, context)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
|
@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
|
|||||||
|
|
||||||
|
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
|
current_gpu_controlnets = []
|
||||||
|
|
||||||
model_accelerated = False
|
model_accelerated = False
|
||||||
|
|
||||||
@ -56,6 +56,7 @@ model_accelerated = False
|
|||||||
def unload_model():
|
def unload_model():
|
||||||
global current_loaded_model
|
global current_loaded_model
|
||||||
global model_accelerated
|
global model_accelerated
|
||||||
|
global current_gpu_controlnets
|
||||||
if current_loaded_model is not None:
|
if current_loaded_model is not None:
|
||||||
if model_accelerated:
|
if model_accelerated:
|
||||||
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||||
@ -64,6 +65,10 @@ def unload_model():
|
|||||||
current_loaded_model.model.cpu()
|
current_loaded_model.model.cpu()
|
||||||
current_loaded_model.unpatch_model()
|
current_loaded_model.unpatch_model()
|
||||||
current_loaded_model = None
|
current_loaded_model = None
|
||||||
|
if len(current_gpu_controlnets) > 0:
|
||||||
|
for n in current_gpu_controlnets:
|
||||||
|
n.cpu()
|
||||||
|
current_gpu_controlnets = []
|
||||||
|
|
||||||
|
|
||||||
def load_model_gpu(model):
|
def load_model_gpu(model):
|
||||||
@ -95,6 +100,16 @@ def load_model_gpu(model):
|
|||||||
model_accelerated = True
|
model_accelerated = True
|
||||||
return current_loaded_model
|
return current_loaded_model
|
||||||
|
|
||||||
|
def load_controlnet_gpu(models):
|
||||||
|
global current_gpu_controlnets
|
||||||
|
for m in current_gpu_controlnets:
|
||||||
|
if m not in models:
|
||||||
|
m.cpu()
|
||||||
|
|
||||||
|
current_gpu_controlnets = []
|
||||||
|
for m in models:
|
||||||
|
current_gpu_controlnets.append(m.cuda())
|
||||||
|
|
||||||
|
|
||||||
def get_free_memory():
|
def get_free_memory():
|
||||||
dev = torch.cuda.current_device()
|
dev = torch.cuda.current_device()
|
||||||
|
@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
uncond = self.inner_model(x, sigma, cond=uncond)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
|
|
||||||
def get_area_and_mult(cond, x_in, cond_concat_in):
|
#The main sampling function shared by all the samplers
|
||||||
|
#Returns predicted noise
|
||||||
|
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
|
||||||
|
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
min_sigma = 0.0
|
|
||||||
max_sigma = 999.0
|
|
||||||
if 'area' in cond[1]:
|
if 'area' in cond[1]:
|
||||||
area = cond[1]['area']
|
area = cond[1]['area']
|
||||||
if 'strength' in cond[1]:
|
if 'strength' in cond[1]:
|
||||||
@ -56,9 +57,15 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
cropped.append(cr)
|
cropped.append(cr)
|
||||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||||
return (input_x, mult, conditionning, area)
|
|
||||||
|
control = None
|
||||||
|
if 'control' in cond[1]:
|
||||||
|
control = cond[1]['control']
|
||||||
|
return (input_x, mult, conditionning, area, control)
|
||||||
|
|
||||||
def cond_equal_size(c1, c2):
|
def cond_equal_size(c1, c2):
|
||||||
|
if c1 is c2:
|
||||||
|
return True
|
||||||
if c1.keys() != c2.keys():
|
if c1.keys() != c2.keys():
|
||||||
return False
|
return False
|
||||||
if 'c_crossattn' in c1:
|
if 'c_crossattn' in c1:
|
||||||
@ -69,6 +76,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def can_concat_cond(c1, c2):
|
||||||
|
if c1[0].shape != c2[0].shape:
|
||||||
|
return False
|
||||||
|
if (c1[4] is None) != (c2[4] is None):
|
||||||
|
return False
|
||||||
|
if c1[4] is not None:
|
||||||
|
if c1[4] is not c2[4]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return cond_equal_size(c1[2], c2[2])
|
||||||
|
|
||||||
def cond_cat(c_list):
|
def cond_cat(c_list):
|
||||||
c_crossattn = []
|
c_crossattn = []
|
||||||
c_concat = []
|
c_concat = []
|
||||||
@ -84,7 +102,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
out['c_concat'] = [torch.cat(c_concat)]
|
out['c_concat'] = [torch.cat(c_concat)]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in)/100000.0
|
out_count = torch.ones_like(x_in)/100000.0
|
||||||
|
|
||||||
@ -96,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
|
|
||||||
to_run = []
|
to_run = []
|
||||||
for x in cond:
|
for x in cond:
|
||||||
p = get_area_and_mult(x, x_in, cond_concat_in)
|
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
to_run += [(p, COND)]
|
||||||
for x in uncond:
|
for x in uncond:
|
||||||
p = get_area_and_mult(x, x_in, cond_concat_in)
|
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -113,8 +131,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
first_shape = first[0][0].shape
|
first_shape = first[0][0].shape
|
||||||
to_batch_temp = []
|
to_batch_temp = []
|
||||||
for x in range(len(to_run)):
|
for x in range(len(to_run)):
|
||||||
if to_run[x][0][0].shape == first_shape:
|
if can_concat_cond(to_run[x][0], first[0]):
|
||||||
if cond_equal_size(to_run[x][0][2], first[0][2]):
|
|
||||||
to_batch_temp += [x]
|
to_batch_temp += [x]
|
||||||
|
|
||||||
to_batch_temp.reverse()
|
to_batch_temp.reverse()
|
||||||
@ -131,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
c = []
|
c = []
|
||||||
cond_or_uncond = []
|
cond_or_uncond = []
|
||||||
area = []
|
area = []
|
||||||
|
control = None
|
||||||
for x in to_batch:
|
for x in to_batch:
|
||||||
o = to_run.pop(x)
|
o = to_run.pop(x)
|
||||||
p = o[0]
|
p = o[0]
|
||||||
@ -139,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
c += [p[2]]
|
c += [p[2]]
|
||||||
area += [p[3]]
|
area += [p[3]]
|
||||||
cond_or_uncond += [o[1]]
|
cond_or_uncond += [o[1]]
|
||||||
|
control = p[4]
|
||||||
|
|
||||||
batch_chunks = len(cond_or_uncond)
|
batch_chunks = len(cond_or_uncond)
|
||||||
input_x = torch.cat(input_x)
|
input_x = torch.cat(input_x)
|
||||||
c = cond_cat(c)
|
c = cond_cat(c)
|
||||||
sigma_ = torch.cat([sigma] * batch_chunks)
|
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||||
|
|
||||||
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
|
if control is not None:
|
||||||
|
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
|
||||||
|
|
||||||
|
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||||
del input_x
|
del input_x
|
||||||
|
|
||||||
for o in range(batch_chunks):
|
for o in range(batch_chunks):
|
||||||
@ -166,10 +188,29 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
|
|||||||
|
|
||||||
|
|
||||||
max_total_area = model_management.maximum_batch_area()
|
max_total_area = model_management.maximum_batch_area()
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
class CFGDenoiserComplex(torch.nn.Module):
|
|
||||||
|
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||||
|
def __init__(self, model, quantize=False, device='cpu'):
|
||||||
|
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||||
|
|
||||||
|
def get_v(self, x, t, cond, **kwargs):
|
||||||
|
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class CFGNoisePredictor(torch.nn.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_model = model
|
||||||
|
self.alphas_cumprod = model.alphas_cumprod
|
||||||
|
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
|
||||||
|
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class KSamplerX0Inpaint(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
@ -177,7 +218,7 @@ class CFGDenoiserComplex(torch.nn.Module):
|
|||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
||||||
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
|
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out *= denoise_mask
|
out *= denoise_mask
|
||||||
|
|
||||||
@ -196,8 +237,6 @@ def simple_scheduler(model, steps):
|
|||||||
def blank_inpaint_image_like(latent_image):
|
def blank_inpaint_image_like(latent_image):
|
||||||
blank_image = torch.ones_like(latent_image)
|
blank_image = torch.ones_like(latent_image)
|
||||||
# these are the values for "zero" in pixel space translated to latent space
|
# these are the values for "zero" in pixel space translated to latent space
|
||||||
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
|
|
||||||
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works
|
|
||||||
blank_image[:,0] *= 0.8223
|
blank_image[:,0] *= 0.8223
|
||||||
blank_image[:,1] *= -0.6876
|
blank_image[:,1] *= -0.6876
|
||||||
blank_image[:,2] *= 0.6364
|
blank_image[:,2] *= 0.6364
|
||||||
@ -234,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
|
|||||||
n = c[1].copy()
|
n = c[1].copy()
|
||||||
conds += [[smallest[0], n]]
|
conds += [[smallest[0], n]]
|
||||||
|
|
||||||
|
|
||||||
|
def apply_control_net_to_equal_area(conds, uncond):
|
||||||
|
cond_cnets = []
|
||||||
|
cond_other = []
|
||||||
|
uncond_cnets = []
|
||||||
|
uncond_other = []
|
||||||
|
for t in range(len(conds)):
|
||||||
|
x = conds[t]
|
||||||
|
if 'area' not in x[1]:
|
||||||
|
if 'control' in x[1] and x[1]['control'] is not None:
|
||||||
|
cond_cnets.append(x[1]['control'])
|
||||||
|
else:
|
||||||
|
cond_other.append((x, t))
|
||||||
|
for t in range(len(uncond)):
|
||||||
|
x = uncond[t]
|
||||||
|
if 'area' not in x[1]:
|
||||||
|
if 'control' in x[1] and x[1]['control'] is not None:
|
||||||
|
uncond_cnets.append(x[1]['control'])
|
||||||
|
else:
|
||||||
|
uncond_other.append((x, t))
|
||||||
|
|
||||||
|
if len(uncond_cnets) > 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for x in range(len(cond_cnets)):
|
||||||
|
temp = uncond_other[x % len(uncond_other)]
|
||||||
|
o = temp[0]
|
||||||
|
if 'control' in o[1] and o[1]['control'] is not None:
|
||||||
|
n = o[1].copy()
|
||||||
|
n['control'] = cond_cnets[x]
|
||||||
|
uncond += [[o[0], n]]
|
||||||
|
else:
|
||||||
|
n = o[1].copy()
|
||||||
|
n['control'] = cond_cnets[x]
|
||||||
|
uncond[temp[1]] = [o[0], n]
|
||||||
|
|
||||||
class KSampler:
|
class KSampler:
|
||||||
SCHEDULERS = ["karras", "normal", "simple"]
|
SCHEDULERS = ["karras", "normal", "simple"]
|
||||||
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
||||||
@ -242,11 +317,13 @@ class KSampler:
|
|||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.model_denoise = CFGNoisePredictor(self.model)
|
||||||
if self.model.parameterization == "v":
|
if self.model.parameterization == "v":
|
||||||
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
|
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
|
||||||
else:
|
else:
|
||||||
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
|
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
|
||||||
self.model_k = CFGDenoiserComplex(self.model_wrap)
|
self.model_wrap.parameterization = self.model.parameterization
|
||||||
|
self.model_k = KSamplerX0Inpaint(self.model_wrap)
|
||||||
self.device = device
|
self.device = device
|
||||||
if scheduler not in self.SCHEDULERS:
|
if scheduler not in self.SCHEDULERS:
|
||||||
scheduler = self.SCHEDULERS[0]
|
scheduler = self.SCHEDULERS[0]
|
||||||
@ -316,6 +393,8 @@ class KSampler:
|
|||||||
for c in negative:
|
for c in negative:
|
||||||
create_cond_with_same_area_if_none(positive, c)
|
create_cond_with_same_area_if_none(positive, c)
|
||||||
|
|
||||||
|
apply_control_net_to_equal_area(positive, negative)
|
||||||
|
|
||||||
if self.model.model.diffusion_model.dtype == torch.float16:
|
if self.model.model.diffusion_model.dtype == torch.float16:
|
||||||
precision_scope = torch.autocast
|
precision_scope = torch.autocast
|
||||||
else:
|
else:
|
||||||
|
76
comfy/sd.py
76
comfy/sd.py
@ -6,6 +6,9 @@ import model_management
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.autoencoder import AutoencoderKL
|
from ldm.models.autoencoder import AutoencoderKL
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from .cldm import cldm
|
||||||
|
|
||||||
|
from . import utils
|
||||||
|
|
||||||
def load_torch_file(ckpt):
|
def load_torch_file(ckpt):
|
||||||
if ckpt.lower().endswith(".safetensors"):
|
if ckpt.lower().endswith(".safetensors"):
|
||||||
@ -323,6 +326,79 @@ class VAE:
|
|||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
class ControlNet:
|
||||||
|
def __init__(self, control_model):
|
||||||
|
self.control_model = control_model
|
||||||
|
self.cond_hint_original = None
|
||||||
|
self.cond_hint = None
|
||||||
|
|
||||||
|
def get_control(self, x_noisy, t, cond_txt):
|
||||||
|
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||||
|
if self.cond_hint is not None:
|
||||||
|
del self.cond_hint
|
||||||
|
self.cond_hint = None
|
||||||
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
|
||||||
|
print("set cond_hint", self.cond_hint.shape)
|
||||||
|
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||||
|
return control
|
||||||
|
|
||||||
|
def set_cond_hint(self, cond_hint):
|
||||||
|
self.cond_hint_original = cond_hint
|
||||||
|
return self
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if self.cond_hint is not None:
|
||||||
|
del self.cond_hint
|
||||||
|
self.cond_hint = None
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
c = ControlNet(self.control_model)
|
||||||
|
c.cond_hint_original = self.cond_hint_original
|
||||||
|
return c
|
||||||
|
|
||||||
|
def load_controlnet(ckpt_path):
|
||||||
|
controlnet_data = load_torch_file(ckpt_path)
|
||||||
|
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||||
|
pth = False
|
||||||
|
sd2 = False
|
||||||
|
key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||||
|
if pth_key in controlnet_data:
|
||||||
|
pth = True
|
||||||
|
key = pth_key
|
||||||
|
elif key in controlnet_data:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("error checkpoint does not contain controlnet data", ckpt_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
context_dim = controlnet_data[key].shape[1]
|
||||||
|
control_model = cldm.ControlNet(image_size=32,
|
||||||
|
in_channels=4,
|
||||||
|
hint_channels=3,
|
||||||
|
model_channels=320,
|
||||||
|
attention_resolutions=[ 4, 2, 1 ],
|
||||||
|
num_res_blocks=2,
|
||||||
|
channel_mult=[ 1, 2, 4, 4 ],
|
||||||
|
num_heads=8,
|
||||||
|
use_spatial_transformer=True,
|
||||||
|
transformer_depth=1,
|
||||||
|
context_dim=context_dim,
|
||||||
|
use_checkpoint=True,
|
||||||
|
legacy=False)
|
||||||
|
|
||||||
|
if pth:
|
||||||
|
class WeightsLoader(torch.nn.Module):
|
||||||
|
pass
|
||||||
|
w = WeightsLoader()
|
||||||
|
w.control_model = control_model
|
||||||
|
w.load_state_dict(controlnet_data, strict=False)
|
||||||
|
else:
|
||||||
|
control_model.load_state_dict(controlnet_data, strict=False)
|
||||||
|
|
||||||
|
control = ControlNet(control_model)
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_path, embedding_directory=None):
|
def load_clip(ckpt_path, embedding_directory=None):
|
||||||
clip_data = load_torch_file(ckpt_path)
|
clip_data = load_torch_file(ckpt_path)
|
||||||
config = {}
|
config = {}
|
||||||
|
18
comfy/utils.py
Normal file
18
comfy/utils.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
|
if crop == "center":
|
||||||
|
old_width = samples.shape[3]
|
||||||
|
old_height = samples.shape[2]
|
||||||
|
old_aspect = old_width / old_height
|
||||||
|
new_aspect = width / height
|
||||||
|
x = 0
|
||||||
|
y = 0
|
||||||
|
if old_aspect > new_aspect:
|
||||||
|
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||||
|
elif old_aspect < new_aspect:
|
||||||
|
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||||
|
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||||
|
else:
|
||||||
|
s = samples
|
||||||
|
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
93
nodes.py
93
nodes.py
@ -15,10 +15,12 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
|
|||||||
|
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
import model_management
|
import model_management
|
||||||
|
|
||||||
supported_ckpt_extensions = ['.ckpt']
|
supported_ckpt_extensions = ['.ckpt', '.pth']
|
||||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin']
|
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
|
||||||
try:
|
try:
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
supported_ckpt_extensions += ['.safetensors']
|
supported_ckpt_extensions += ['.safetensors']
|
||||||
@ -77,12 +79,14 @@ class ConditioningSetArea:
|
|||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
|
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
|
||||||
c = copy.deepcopy(conditioning)
|
c = []
|
||||||
for t in c:
|
for t in conditioning:
|
||||||
t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
n = [t[0], t[1].copy()]
|
||||||
t[1]['strength'] = strength
|
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||||
t[1]['min_sigma'] = min_sigma
|
n[1]['strength'] = strength
|
||||||
t[1]['max_sigma'] = max_sigma
|
n[1]['min_sigma'] = min_sigma
|
||||||
|
n[1]['max_sigma'] = max_sigma
|
||||||
|
c.append(n)
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class VAEDecode:
|
class VAEDecode:
|
||||||
@ -134,7 +138,6 @@ class VAEEncodeForInpaint:
|
|||||||
CATEGORY = "latent/inpaint"
|
CATEGORY = "latent/inpaint"
|
||||||
|
|
||||||
def encode(self, vae, pixels, mask):
|
def encode(self, vae, pixels, mask):
|
||||||
print(pixels.shape, mask.shape)
|
|
||||||
x = (pixels.shape[1] // 64) * 64
|
x = (pixels.shape[1] // 64) * 64
|
||||||
y = (pixels.shape[2] // 64) * 64
|
y = (pixels.shape[2] // 64) * 64
|
||||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||||
@ -144,7 +147,6 @@ class VAEEncodeForInpaint:
|
|||||||
#shave off a few pixels to keep things seamless
|
#shave off a few pixels to keep things seamless
|
||||||
kernel_tensor = torch.ones((1, 1, 6, 6))
|
kernel_tensor = torch.ones((1, 1, 6, 6))
|
||||||
mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
|
mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
|
||||||
print(mask_erosion.shape, pixels.shape)
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
pixels[:,:,:,i] -= 0.5
|
pixels[:,:,:,i] -= 0.5
|
||||||
pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
|
pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
|
||||||
@ -211,6 +213,44 @@ class VAELoader:
|
|||||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
|
class ControlNetLoader:
|
||||||
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||||
|
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONTROL_NET",)
|
||||||
|
FUNCTION = "load_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
|
def load_controlnet(self, control_net_name):
|
||||||
|
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||||
|
controlnet = comfy.sd.load_controlnet(controlnet_path)
|
||||||
|
return (controlnet,)
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetApply:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}}
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "apply_controlnet"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
|
def apply_controlnet(self, conditioning, control_net, image):
|
||||||
|
c = []
|
||||||
|
control_hint = image.movedim(-1,1)
|
||||||
|
print(control_hint.shape)
|
||||||
|
for t in conditioning:
|
||||||
|
n = [t[0], t[1].copy()]
|
||||||
|
n[1]['control'] = control_net.copy().set_cond_hint(control_hint)
|
||||||
|
c.append(n)
|
||||||
|
return (c, )
|
||||||
|
|
||||||
|
|
||||||
class CLIPLoader:
|
class CLIPLoader:
|
||||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||||
clip_dir = os.path.join(models_dir, "clip")
|
clip_dir = os.path.join(models_dir, "clip")
|
||||||
@ -248,22 +288,7 @@ class EmptyLatentImage:
|
|||||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||||
return ({"samples":latent}, )
|
return ({"samples":latent}, )
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
|
||||||
if crop == "center":
|
|
||||||
old_width = samples.shape[3]
|
|
||||||
old_height = samples.shape[2]
|
|
||||||
old_aspect = old_width / old_height
|
|
||||||
new_aspect = width / height
|
|
||||||
x = 0
|
|
||||||
y = 0
|
|
||||||
if old_aspect > new_aspect:
|
|
||||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
|
||||||
elif old_aspect < new_aspect:
|
|
||||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
|
||||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
|
||||||
else:
|
|
||||||
s = samples
|
|
||||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
|
||||||
|
|
||||||
class LatentUpscale:
|
class LatentUpscale:
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||||
@ -282,7 +307,7 @@ class LatentUpscale:
|
|||||||
|
|
||||||
def upscale(self, samples, upscale_method, width, height, crop):
|
def upscale(self, samples, upscale_method, width, height, crop):
|
||||||
s = samples.copy()
|
s = samples.copy()
|
||||||
s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class LatentRotate:
|
class LatentRotate:
|
||||||
@ -461,19 +486,26 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
|||||||
positive_copy = []
|
positive_copy = []
|
||||||
negative_copy = []
|
negative_copy = []
|
||||||
|
|
||||||
|
control_nets = []
|
||||||
for p in positive:
|
for p in positive:
|
||||||
t = p[0]
|
t = p[0]
|
||||||
if t.shape[0] < noise.shape[0]:
|
if t.shape[0] < noise.shape[0]:
|
||||||
t = torch.cat([t] * noise.shape[0])
|
t = torch.cat([t] * noise.shape[0])
|
||||||
t = t.to(device)
|
t = t.to(device)
|
||||||
|
if 'control' in p[1]:
|
||||||
|
control_nets += [p[1]['control']]
|
||||||
positive_copy += [[t] + p[1:]]
|
positive_copy += [[t] + p[1:]]
|
||||||
for n in negative:
|
for n in negative:
|
||||||
t = n[0]
|
t = n[0]
|
||||||
if t.shape[0] < noise.shape[0]:
|
if t.shape[0] < noise.shape[0]:
|
||||||
t = torch.cat([t] * noise.shape[0])
|
t = torch.cat([t] * noise.shape[0])
|
||||||
t = t.to(device)
|
t = t.to(device)
|
||||||
|
if 'control' in p[1]:
|
||||||
|
control_nets += [p[1]['control']]
|
||||||
negative_copy += [[t] + n[1:]]
|
negative_copy += [[t] + n[1:]]
|
||||||
|
|
||||||
|
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
|
||||||
|
|
||||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
||||||
else:
|
else:
|
||||||
@ -482,6 +514,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
|||||||
|
|
||||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
|
for c in control_nets:
|
||||||
|
c.cleanup()
|
||||||
|
|
||||||
out = latent.copy()
|
out = latent.copy()
|
||||||
out["samples"] = samples
|
out["samples"] = samples
|
||||||
return (out, )
|
return (out, )
|
||||||
@ -676,7 +711,7 @@ class ImageScale:
|
|||||||
|
|
||||||
def upscale(self, image, upscale_method, width, height, crop):
|
def upscale(self, image, upscale_method, width, height, crop):
|
||||||
samples = image.movedim(-1,1)
|
samples = image.movedim(-1,1)
|
||||||
s = common_upscale(samples, width, height, upscale_method, crop)
|
s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
|
||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
@ -704,6 +739,8 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LatentCrop": LatentCrop,
|
"LatentCrop": LatentCrop,
|
||||||
"LoraLoader": LoraLoader,
|
"LoraLoader": LoraLoader,
|
||||||
"CLIPLoader": CLIPLoader,
|
"CLIPLoader": CLIPLoader,
|
||||||
|
"ControlNetApply": ControlNetApply,
|
||||||
|
"ControlNetLoader": ControlNetLoader,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user