mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
fa66ece26b
286
comfy/cldm/cldm.py
Normal file
286
comfy/cldm/cldm.py
Normal file
@ -0,0 +1,286 @@
|
||||
#taken from: https://github.com/lllyasviel/ControlNet
|
||||
#and modified
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
conv_nd,
|
||||
linear,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
#implemented in the ldm unet
|
||||
pass
|
||||
|
||||
class ControlNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.dims = dims
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
conv_nd(dims, hint_channels, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
||||
)
|
||||
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self.middle_block_out = self.make_zero_conv(ch)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels):
|
||||
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
|
||||
outs = []
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
h += guided_hint
|
||||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
outs.append(zero_conv(h, emb, context))
|
||||
|
||||
h = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
return outs
|
||||
|
@ -358,7 +358,10 @@ class UniPC:
|
||||
predict_x0=True,
|
||||
thresholding=False,
|
||||
max_val=1.,
|
||||
variant='bh1'
|
||||
variant='bh1',
|
||||
noise_mask=None,
|
||||
masked_image=None,
|
||||
noise=None,
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
|
||||
@ -370,7 +373,10 @@ class UniPC:
|
||||
self.predict_x0 = predict_x0
|
||||
self.thresholding = thresholding
|
||||
self.max_val = max_val
|
||||
|
||||
self.noise_mask = noise_mask
|
||||
self.masked_image = masked_image
|
||||
self.noise = noise
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
The dynamic thresholding method.
|
||||
@ -386,7 +392,10 @@ class UniPC:
|
||||
"""
|
||||
Return the noise prediction model.
|
||||
"""
|
||||
return self.model(x, t)
|
||||
if self.noise_mask is not None:
|
||||
return self.model(x, t) * self.noise_mask
|
||||
else:
|
||||
return self.model(x, t)
|
||||
|
||||
def data_prediction_fn(self, x, t):
|
||||
"""
|
||||
@ -401,6 +410,8 @@ class UniPC:
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
if self.noise_mask is not None:
|
||||
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
|
||||
return x0
|
||||
|
||||
def model_fn(self, x, t):
|
||||
@ -713,6 +724,8 @@ class UniPC:
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
# with torch.no_grad():
|
||||
for step_index in trange(steps):
|
||||
if self.noise_mask is not None:
|
||||
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
||||
if step_index == 0:
|
||||
vec_t = timesteps[0].expand((x.shape[0]))
|
||||
model_prev_list = [self.model_fn(x, vec_t)]
|
||||
@ -820,7 +833,7 @@ def expand_dims(v, dims):
|
||||
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
|
||||
to_zero = False
|
||||
if sigmas[-1] == 0:
|
||||
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
||||
@ -843,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
|
||||
|
||||
device = noise.device
|
||||
|
||||
if model.inner_model.parameterization == "v":
|
||||
if model.parameterization == "v":
|
||||
model_type = "v"
|
||||
else:
|
||||
model_type = "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
model.inner_model.apply_model,
|
||||
model.inner_model.inner_model.apply_model,
|
||||
sampling_function,
|
||||
ns,
|
||||
model_type=model_type,
|
||||
@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
|
||||
model_kwargs=extra_args,
|
||||
)
|
||||
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
|
||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
|
||||
if not to_zero:
|
||||
x /= ns.marginal_alpha(timesteps[-1])
|
||||
|
@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
self.conditioning_key = conditioning_key
|
||||
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
||||
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
|
||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
|
||||
if self.conditioning_key is None:
|
||||
out = self.diffusion_model(x, t)
|
||||
out = self.diffusion_model(x, t, control=control)
|
||||
elif self.conditioning_key == 'concat':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
out = self.diffusion_model(xc, t)
|
||||
out = self.diffusion_model(xc, t, control=control)
|
||||
elif self.conditioning_key == 'crossattn':
|
||||
if not self.sequential_cross_attn:
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
|
||||
# TorchScript changes names of the arguments
|
||||
# with argument cc defined as context=cc scripted model will produce
|
||||
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||
out = self.scripted_diffusion_model(x, t, cc)
|
||||
out = self.scripted_diffusion_model(x, t, cc, control=control)
|
||||
else:
|
||||
out = self.diffusion_model(x, t, context=cc)
|
||||
out = self.diffusion_model(x, t, context=cc, control=control)
|
||||
elif self.conditioning_key == 'hybrid':
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc)
|
||||
out = self.diffusion_model(xc, t, context=cc, control=control)
|
||||
elif self.conditioning_key == 'hybrid-adm':
|
||||
assert c_adm is not None
|
||||
xc = torch.cat([x] + c_concat, dim=1)
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
|
||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
|
||||
elif self.conditioning_key == 'crossattn-adm':
|
||||
assert c_adm is not None
|
||||
cc = torch.cat(c_crossattn, 1)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm)
|
||||
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
|
||||
elif self.conditioning_key == 'adm':
|
||||
cc = c_crossattn[0]
|
||||
out = self.diffusion_model(x, t, y=cc)
|
||||
out = self.diffusion_model(x, t, y=cc, control=control)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -753,7 +753,7 @@ class UNetModel(nn.Module):
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
@ -778,8 +778,14 @@ class UNetModel(nn.Module):
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
if control is not None:
|
||||
h += control.pop()
|
||||
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
hsp = hs.pop()
|
||||
if control is not None:
|
||||
hsp += control.pop()
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
|
@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
|
||||
|
||||
|
||||
current_loaded_model = None
|
||||
|
||||
current_gpu_controlnets = []
|
||||
|
||||
model_accelerated = False
|
||||
|
||||
@ -56,6 +56,7 @@ model_accelerated = False
|
||||
def unload_model():
|
||||
global current_loaded_model
|
||||
global model_accelerated
|
||||
global current_gpu_controlnets
|
||||
if current_loaded_model is not None:
|
||||
if model_accelerated:
|
||||
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
|
||||
@ -64,6 +65,10 @@ def unload_model():
|
||||
current_loaded_model.model.cpu()
|
||||
current_loaded_model.unpatch_model()
|
||||
current_loaded_model = None
|
||||
if len(current_gpu_controlnets) > 0:
|
||||
for n in current_gpu_controlnets:
|
||||
n.cpu()
|
||||
current_gpu_controlnets = []
|
||||
|
||||
|
||||
def load_model_gpu(model):
|
||||
@ -95,6 +100,16 @@ def load_model_gpu(model):
|
||||
model_accelerated = True
|
||||
return current_loaded_model
|
||||
|
||||
def load_controlnet_gpu(models):
|
||||
global current_gpu_controlnets
|
||||
for m in current_gpu_controlnets:
|
||||
if m not in models:
|
||||
m.cpu()
|
||||
|
||||
current_gpu_controlnets = []
|
||||
for m in models:
|
||||
current_gpu_controlnets.append(m.cuda())
|
||||
|
||||
|
||||
def get_free_memory():
|
||||
dev = torch.cuda.current_device()
|
||||
|
@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
|
||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
def get_area_and_mult(cond, x_in):
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
|
||||
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
min_sigma = 0.0
|
||||
max_sigma = 999.0
|
||||
if 'area' in cond[1]:
|
||||
area = cond[1]['area']
|
||||
if 'strength' in cond[1]:
|
||||
@ -48,9 +49,60 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
|
||||
return (input_x, mult, cond[0], area)
|
||||
conditionning = {}
|
||||
conditionning['c_crossattn'] = cond[0]
|
||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||
cropped = []
|
||||
for x in cond_concat_in:
|
||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
cropped.append(cr)
|
||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area):
|
||||
control = None
|
||||
if 'control' in cond[1]:
|
||||
control = cond[1]['control']
|
||||
return (input_x, mult, conditionning, area, control)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
return True
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
||||
return False
|
||||
if 'c_concat' in c1:
|
||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||
return False
|
||||
return True
|
||||
|
||||
def can_concat_cond(c1, c2):
|
||||
if c1[0].shape != c2[0].shape:
|
||||
return False
|
||||
if (c1[4] is None) != (c2[4] is None):
|
||||
return False
|
||||
if c1[4] is not None:
|
||||
if c1[4] is not c2[4]:
|
||||
return False
|
||||
|
||||
return cond_equal_size(c1[2], c2[2])
|
||||
|
||||
def cond_cat(c_list):
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
for x in c_list:
|
||||
if 'c_crossattn' in x:
|
||||
c_crossattn.append(x['c_crossattn'])
|
||||
if 'c_concat' in x:
|
||||
c_concat.append(x['c_concat'])
|
||||
out = {}
|
||||
if len(c_crossattn) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = [torch.cat(c_concat)]
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in)/100000.0
|
||||
|
||||
@ -62,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in)
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in)
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
@ -79,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if to_run[x][0][0].shape == first_shape:
|
||||
if to_run[x][0][2].shape == first[0][2].shape:
|
||||
to_batch_temp += [x]
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
@ -97,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
@ -105,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
c += [p[2]]
|
||||
area += [p[3]]
|
||||
cond_or_uncond += [o[1]]
|
||||
control = p[4]
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = torch.cat(c)
|
||||
sigma_ = torch.cat([sigma] * batch_chunks)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
|
||||
|
||||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
@ -132,15 +188,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
|
||||
max_total_area = model_management.maximum_batch_area()
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area)
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
class CFGDenoiserComplex(torch.nn.Module):
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
||||
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
|
||||
return out
|
||||
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
|
||||
if denoise_mask is not None:
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
|
||||
if denoise_mask is not None:
|
||||
out *= denoise_mask
|
||||
|
||||
if denoise_mask is not None:
|
||||
out += self.latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
sigs = []
|
||||
@ -150,6 +234,15 @@ def simple_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
if 'area' not in c[1]:
|
||||
return
|
||||
@ -180,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
|
||||
n = c[1].copy()
|
||||
conds += [[smallest[0], n]]
|
||||
|
||||
|
||||
def apply_control_net_to_equal_area(conds, uncond):
|
||||
cond_cnets = []
|
||||
cond_other = []
|
||||
uncond_cnets = []
|
||||
uncond_other = []
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
if 'area' not in x[1]:
|
||||
if 'control' in x[1] and x[1]['control'] is not None:
|
||||
cond_cnets.append(x[1]['control'])
|
||||
else:
|
||||
cond_other.append((x, t))
|
||||
for t in range(len(uncond)):
|
||||
x = uncond[t]
|
||||
if 'area' not in x[1]:
|
||||
if 'control' in x[1] and x[1]['control'] is not None:
|
||||
uncond_cnets.append(x[1]['control'])
|
||||
else:
|
||||
uncond_other.append((x, t))
|
||||
|
||||
if len(uncond_cnets) > 0:
|
||||
return
|
||||
|
||||
for x in range(len(cond_cnets)):
|
||||
temp = uncond_other[x % len(uncond_other)]
|
||||
o = temp[0]
|
||||
if 'control' in o[1] and o[1]['control'] is not None:
|
||||
n = o[1].copy()
|
||||
n['control'] = cond_cnets[x]
|
||||
uncond += [[o[0], n]]
|
||||
else:
|
||||
n = o[1].copy()
|
||||
n['control'] = cond_cnets[x]
|
||||
uncond[temp[1]] = [o[0], n]
|
||||
|
||||
class KSampler:
|
||||
SCHEDULERS = ["karras", "normal", "simple"]
|
||||
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
|
||||
@ -188,11 +317,13 @@ class KSampler:
|
||||
|
||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
||||
self.model = model
|
||||
self.model_denoise = CFGNoisePredictor(self.model)
|
||||
if self.model.parameterization == "v":
|
||||
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
|
||||
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
|
||||
else:
|
||||
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
|
||||
self.model_k = CFGDenoiserComplex(self.model_wrap)
|
||||
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
|
||||
self.model_wrap.parameterization = self.model.parameterization
|
||||
self.model_k = KSamplerX0Inpaint(self.model_wrap)
|
||||
self.device = device
|
||||
if scheduler not in self.SCHEDULERS:
|
||||
scheduler = self.SCHEDULERS[0]
|
||||
@ -200,8 +331,8 @@ class KSampler:
|
||||
sampler = self.SAMPLERS[0]
|
||||
self.scheduler = scheduler
|
||||
self.sampler = sampler
|
||||
self.sigma_min=float(self.model_wrap.sigmas[0])
|
||||
self.sigma_max=float(self.model_wrap.sigmas[-1])
|
||||
self.sigma_min=float(self.model_wrap.sigma_min)
|
||||
self.sigma_max=float(self.model_wrap.sigma_max)
|
||||
self.set_steps(steps, denoise)
|
||||
|
||||
def _calculate_sigmas(self, steps):
|
||||
@ -235,7 +366,7 @@ class KSampler:
|
||||
self.sigmas = sigmas[-(steps + 1):]
|
||||
|
||||
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False):
|
||||
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
|
||||
sigmas = self.sigmas
|
||||
sigma_min = self.sigma_min
|
||||
|
||||
@ -262,22 +393,47 @@ class KSampler:
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
apply_control_net_to_equal_area(positive, negative)
|
||||
|
||||
if self.model.model.diffusion_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
||||
|
||||
if hasattr(self.model, 'concat_keys'):
|
||||
cond_concat = []
|
||||
for ck in self.model.concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
extra_args["cond_concat"] = cond_concat
|
||||
|
||||
with precision_scope(self.device):
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
|
||||
else:
|
||||
noise *= sigmas[0]
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
self.model_k.latent_image = latent_image
|
||||
self.model_k.noise = noise
|
||||
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if self.sampler == "sample_dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
|
||||
elif self.sampler == "sample_dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
|
||||
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args)
|
||||
|
||||
return samples.to(torch.float32)
|
||||
|
81
comfy/sd.py
81
comfy/sd.py
@ -6,6 +6,9 @@ import model_management
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.autoencoder import AutoencoderKL
|
||||
from omegaconf import OmegaConf
|
||||
from .cldm import cldm
|
||||
|
||||
from . import utils
|
||||
|
||||
def load_torch_file(ckpt):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
@ -323,6 +326,84 @@ class VAE:
|
||||
samples = samples.cpu()
|
||||
return samples
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt):
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
|
||||
print("set cond_hint", self.cond_hint.shape)
|
||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||
for x in control:
|
||||
x *= self.strength
|
||||
return control
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
return self
|
||||
|
||||
def cleanup(self):
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model)
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
return c
|
||||
|
||||
def load_controlnet(ckpt_path):
|
||||
controlnet_data = load_torch_file(ckpt_path)
|
||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
pth = False
|
||||
sd2 = False
|
||||
key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
if pth_key in controlnet_data:
|
||||
pth = True
|
||||
key = pth_key
|
||||
elif key in controlnet_data:
|
||||
pass
|
||||
else:
|
||||
print("error checkpoint does not contain controlnet data", ckpt_path)
|
||||
return None
|
||||
|
||||
context_dim = controlnet_data[key].shape[1]
|
||||
control_model = cldm.ControlNet(image_size=32,
|
||||
in_channels=4,
|
||||
hint_channels=3,
|
||||
model_channels=320,
|
||||
attention_resolutions=[ 4, 2, 1 ],
|
||||
num_res_blocks=2,
|
||||
channel_mult=[ 1, 2, 4, 4 ],
|
||||
num_heads=8,
|
||||
use_spatial_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
legacy=False)
|
||||
|
||||
if pth:
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.control_model = control_model
|
||||
w.load_state_dict(controlnet_data, strict=False)
|
||||
else:
|
||||
control_model.load_state_dict(controlnet_data, strict=False)
|
||||
|
||||
control = ControlNet(control_model)
|
||||
return control
|
||||
|
||||
|
||||
def load_clip(ckpt_path, embedding_directory=None):
|
||||
clip_data = load_torch_file(ckpt_path)
|
||||
config = {}
|
||||
|
18
comfy/utils.py
Normal file
18
comfy/utils.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
old_height = samples.shape[2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
y = 0
|
||||
if old_aspect > new_aspect:
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||
else:
|
||||
s = samples
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
4
main.py
4
main.py
@ -7,6 +7,10 @@ import heapq
|
||||
import traceback
|
||||
import asyncio
|
||||
|
||||
if os.name == "nt":
|
||||
import logging
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
158
models/configs/v2-inpainting-inference.yaml
Normal file
158
models/configs/v2-inpainting-inference.yaml
Normal file
@ -0,0 +1,158 @@
|
||||
model:
|
||||
base_learning_rate: 5.0e-05
|
||||
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: hybrid
|
||||
scale_factor: 0.18215
|
||||
monitor: val/loss_simple_ema
|
||||
finetune_keys: null
|
||||
use_ema: False
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 9
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: [ ]
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
|
||||
|
||||
data:
|
||||
target: ldm.data.laion.WebDataModuleFromConfig
|
||||
params:
|
||||
tar_base: null # for concat as in LAION-A
|
||||
p_unsafe_threshold: 0.1
|
||||
filter_word_list: "data/filters.yaml"
|
||||
max_pwatermark: 0.45
|
||||
batch_size: 8
|
||||
num_workers: 6
|
||||
multinode: True
|
||||
min_size: 512
|
||||
train:
|
||||
shards:
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
|
||||
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
|
||||
shuffle: 10000
|
||||
image_key: jpg
|
||||
image_transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
size: 512
|
||||
interpolation: 3
|
||||
- target: torchvision.transforms.RandomCrop
|
||||
params:
|
||||
size: 512
|
||||
postprocess:
|
||||
target: ldm.data.laion.AddMask
|
||||
params:
|
||||
mode: "512train-large"
|
||||
p_drop: 0.25
|
||||
# NOTE use enough shards to avoid empty validation loops in workers
|
||||
validation:
|
||||
shards:
|
||||
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
|
||||
shuffle: 0
|
||||
image_key: jpg
|
||||
image_transforms:
|
||||
- target: torchvision.transforms.Resize
|
||||
params:
|
||||
size: 512
|
||||
interpolation: 3
|
||||
- target: torchvision.transforms.CenterCrop
|
||||
params:
|
||||
size: 512
|
||||
postprocess:
|
||||
target: ldm.data.laion.AddMask
|
||||
params:
|
||||
mode: "512train-large"
|
||||
p_drop: 0.25
|
||||
|
||||
lightning:
|
||||
find_unused_parameters: True
|
||||
modelcheckpoint:
|
||||
params:
|
||||
every_n_train_steps: 5000
|
||||
|
||||
callbacks:
|
||||
metrics_over_trainsteps_checkpoint:
|
||||
params:
|
||||
every_n_train_steps: 10000
|
||||
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
enable_autocast: False
|
||||
disabled: False
|
||||
batch_frequency: 1000
|
||||
max_images: 4
|
||||
increase_log_steps: False
|
||||
log_first_step: False
|
||||
log_images_kwargs:
|
||||
use_ema_scope: False
|
||||
inpaint: False
|
||||
plot_progressive_rows: False
|
||||
plot_diffusion_rows: False
|
||||
N: 4
|
||||
unconditional_guidance_scale: 5.0
|
||||
unconditional_guidance_label: [""]
|
||||
ddim_steps: 50 # todo check these out for depth2img,
|
||||
ddim_eta: 0.0 # todo check these out for depth2img,
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
||||
val_check_interval: 5000000
|
||||
num_sanity_val_steps: 0
|
||||
accumulate_grad_batches: 1
|
0
models/controlnet/put_controlnets_here
Normal file
0
models/controlnet/put_controlnets_here
Normal file
242
nodes.py
242
nodes.py
@ -15,11 +15,13 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
|
||||
|
||||
import comfy.samplers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
|
||||
import model_management
|
||||
import importlib
|
||||
|
||||
supported_ckpt_extensions = ['.ckpt']
|
||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin']
|
||||
supported_ckpt_extensions = ['.ckpt', '.pth']
|
||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
|
||||
try:
|
||||
import safetensors.torch
|
||||
supported_ckpt_extensions += ['.safetensors']
|
||||
@ -78,12 +80,14 @@ class ConditioningSetArea:
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
|
||||
c = copy.deepcopy(conditioning)
|
||||
for t in c:
|
||||
t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||
t[1]['strength'] = strength
|
||||
t[1]['min_sigma'] = min_sigma
|
||||
t[1]['max_sigma'] = max_sigma
|
||||
c = []
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
|
||||
n[1]['strength'] = strength
|
||||
n[1]['min_sigma'] = min_sigma
|
||||
n[1]['max_sigma'] = max_sigma
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
class VAEDecode:
|
||||
@ -99,7 +103,7 @@ class VAEDecode:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def decode(self, vae, samples):
|
||||
return (vae.decode(samples), )
|
||||
return (vae.decode(samples["samples"]), )
|
||||
|
||||
class VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
@ -118,7 +122,39 @@ class VAEEncode:
|
||||
y = (pixels.shape[2] // 64) * 64
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
pixels = pixels[:,:x,:y,:]
|
||||
return (vae.encode(pixels), )
|
||||
t = vae.encode(pixels[:,:,:,:3])
|
||||
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "latent/inpaint"
|
||||
|
||||
def encode(self, vae, pixels, mask):
|
||||
x = (pixels.shape[1] // 64) * 64
|
||||
y = (pixels.shape[2] // 64) * 64
|
||||
if pixels.shape[1] != x or pixels.shape[2] != y:
|
||||
pixels = pixels[:,:x,:y,:]
|
||||
mask = mask[:x,:y]
|
||||
|
||||
#shave off a few pixels to keep things seamless
|
||||
kernel_tensor = torch.ones((1, 1, 6, 6))
|
||||
mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
|
||||
for i in range(3):
|
||||
pixels[:,:,:,i] -= 0.5
|
||||
pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
|
||||
pixels[:,:,:,i] += 0.5
|
||||
t = vae.encode(pixels)
|
||||
|
||||
return ({"samples":t, "noise_mask": mask}, )
|
||||
|
||||
class CheckpointLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
@ -178,6 +214,48 @@ class VAELoader:
|
||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_controlnet"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, control_net_name):
|
||||
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||
controlnet = comfy.sd.load_controlnet(controlnet_path)
|
||||
return (controlnet,)
|
||||
|
||||
|
||||
class ControlNetApply:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning": ("CONDITIONING", ),
|
||||
"control_net": ("CONTROL_NET", ),
|
||||
"image": ("IMAGE", ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_controlnet"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def apply_controlnet(self, conditioning, control_net, image, strength):
|
||||
c = []
|
||||
control_hint = image.movedim(-1,1)
|
||||
print(control_hint.shape)
|
||||
for t in conditioning:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength)
|
||||
c.append(n)
|
||||
return (c, )
|
||||
|
||||
|
||||
class CLIPLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
clip_dir = os.path.join(models_dir, "clip")
|
||||
@ -213,24 +291,9 @@ class EmptyLatentImage:
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||
return (latent, )
|
||||
return ({"samples":latent}, )
|
||||
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
old_width = samples.shape[3]
|
||||
old_height = samples.shape[2]
|
||||
old_aspect = old_width / old_height
|
||||
new_aspect = width / height
|
||||
x = 0
|
||||
y = 0
|
||||
if old_aspect > new_aspect:
|
||||
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
||||
elif old_aspect < new_aspect:
|
||||
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
||||
s = samples[:,:,y:old_height-y,x:old_width-x]
|
||||
else:
|
||||
s = samples
|
||||
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
||||
|
||||
class LatentUpscale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
@ -248,7 +311,8 @@ class LatentUpscale:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def upscale(self, samples, upscale_method, width, height, crop):
|
||||
s = common_upscale(samples, width // 8, height // 8, upscale_method, crop)
|
||||
s = samples.copy()
|
||||
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
|
||||
return (s,)
|
||||
|
||||
class LatentRotate:
|
||||
@ -263,6 +327,7 @@ class LatentRotate:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def rotate(self, samples, rotation):
|
||||
s = samples.copy()
|
||||
rotate_by = 0
|
||||
if rotation.startswith("90"):
|
||||
rotate_by = 1
|
||||
@ -271,7 +336,7 @@ class LatentRotate:
|
||||
elif rotation.startswith("270"):
|
||||
rotate_by = 3
|
||||
|
||||
s = torch.rot90(samples, k=rotate_by, dims=[3, 2])
|
||||
s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
|
||||
return (s,)
|
||||
|
||||
class LatentFlip:
|
||||
@ -286,12 +351,11 @@ class LatentFlip:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def flip(self, samples, flip_method):
|
||||
s = samples.copy()
|
||||
if flip_method.startswith("x"):
|
||||
s = torch.flip(samples, dims=[2])
|
||||
s["samples"] = torch.flip(samples["samples"], dims=[2])
|
||||
elif flip_method.startswith("y"):
|
||||
s = torch.flip(samples, dims=[3])
|
||||
else:
|
||||
s = samples
|
||||
s["samples"] = torch.flip(samples["samples"], dims=[3])
|
||||
|
||||
return (s,)
|
||||
|
||||
@ -313,12 +377,15 @@ class LatentComposite:
|
||||
x = x // 8
|
||||
y = y // 8
|
||||
feather = feather // 8
|
||||
s = samples_to.clone()
|
||||
samples_out = samples_to.copy()
|
||||
s = samples_to["samples"].clone()
|
||||
samples_to = samples_to["samples"]
|
||||
samples_from = samples_from["samples"]
|
||||
if feather == 0:
|
||||
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
|
||||
else:
|
||||
s_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
|
||||
mask = torch.ones_like(s_from)
|
||||
samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
|
||||
mask = torch.ones_like(samples_from)
|
||||
for t in range(feather):
|
||||
if y != 0:
|
||||
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
|
||||
@ -331,7 +398,8 @@ class LatentComposite:
|
||||
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
|
||||
rev_mask = torch.ones_like(mask) - mask
|
||||
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
|
||||
return (s,)
|
||||
samples_out["samples"] = s
|
||||
return (samples_out,)
|
||||
|
||||
class LatentCrop:
|
||||
@classmethod
|
||||
@ -348,6 +416,8 @@ class LatentCrop:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def crop(self, samples, width, height, x, y):
|
||||
s = samples.copy()
|
||||
samples = samples['samples']
|
||||
x = x // 8
|
||||
y = y // 8
|
||||
|
||||
@ -371,15 +441,43 @@ class LatentCrop:
|
||||
#make sure size is always multiple of 64
|
||||
x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
|
||||
y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
|
||||
s = samples[:,:,y:to_y, x:to_x]
|
||||
s['samples'] = samples[:,:,y:to_y, x:to_x]
|
||||
return (s,)
|
||||
|
||||
def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
class SetLatentNoiseMask:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"mask": ("MASK",),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "set_mask"
|
||||
|
||||
CATEGORY = "latent/inpaint"
|
||||
|
||||
def set_mask(self, samples, mask):
|
||||
s = samples.copy()
|
||||
s["noise_mask"] = mask
|
||||
return (s,)
|
||||
|
||||
|
||||
def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = None
|
||||
|
||||
if disable_noise:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
|
||||
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent['noise_mask']
|
||||
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
|
||||
noise_mask = noise_mask.round()
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
|
||||
noise_mask = torch.cat([noise_mask] * noise.shape[0])
|
||||
noise_mask = noise_mask.to(device)
|
||||
|
||||
real_model = None
|
||||
if device != "cpu":
|
||||
model_management.load_model_gpu(model)
|
||||
@ -393,29 +491,40 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
||||
positive_copy = []
|
||||
negative_copy = []
|
||||
|
||||
control_nets = []
|
||||
for p in positive:
|
||||
t = p[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in p[1]:
|
||||
control_nets += [p[1]['control']]
|
||||
positive_copy += [[t] + p[1:]]
|
||||
for n in negative:
|
||||
t = n[0]
|
||||
if t.shape[0] < noise.shape[0]:
|
||||
t = torch.cat([t] * noise.shape[0])
|
||||
t = t.to(device)
|
||||
if 'control' in p[1]:
|
||||
control_nets += [p[1]['control']]
|
||||
negative_copy += [[t] + n[1:]]
|
||||
|
||||
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
|
||||
|
||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
||||
else:
|
||||
#other samplers
|
||||
pass
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
|
||||
samples = samples.cpu()
|
||||
for c in control_nets:
|
||||
c.cleanup()
|
||||
|
||||
return (samples, )
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
return (out, )
|
||||
|
||||
class KSampler:
|
||||
def __init__(self, device="cuda"):
|
||||
@ -532,7 +641,7 @@ class LoadImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"image": (os.listdir(s.input_dir), )},
|
||||
{"image": (sorted(os.listdir(s.input_dir)), )},
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
@ -541,10 +650,11 @@ class LoadImage:
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image):
|
||||
image_path = os.path.join(self.input_dir, image)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
i = Image.open(image_path)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image[None])[None,]
|
||||
return image
|
||||
image = torch.from_numpy(image)[None,]
|
||||
return (image,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
@ -554,6 +664,41 @@ class LoadImage:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
class LoadImageMask:
|
||||
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"image": (os.listdir(s.input_dir), ),
|
||||
"channel": (["alpha", "red", "green", "blue"], ),}
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
image_path = os.path.join(self.input_dir, image)
|
||||
i = Image.open(image_path)
|
||||
mask = None
|
||||
c = channel[0].upper()
|
||||
if c in i.getbands():
|
||||
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(mask)
|
||||
if c == 'A':
|
||||
mask = 1. - mask
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (mask,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
image_path = os.path.join(s.input_dir, image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
class ImageScale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
@ -571,7 +716,7 @@ class ImageScale:
|
||||
|
||||
def upscale(self, image, upscale_method, width, height, crop):
|
||||
samples = image.movedim(-1,1)
|
||||
s = common_upscale(samples, width, height, upscale_method, crop)
|
||||
s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
|
||||
s = s.movedim(1,-1)
|
||||
return (s,)
|
||||
|
||||
@ -581,21 +726,26 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CLIPTextEncode": CLIPTextEncode,
|
||||
"VAEDecode": VAEDecode,
|
||||
"VAEEncode": VAEEncode,
|
||||
"VAEEncodeForInpaint": VAEEncodeForInpaint,
|
||||
"VAELoader": VAELoader,
|
||||
"EmptyLatentImage": EmptyLatentImage,
|
||||
"LatentUpscale": LatentUpscale,
|
||||
"SaveImage": SaveImage,
|
||||
"LoadImage": LoadImage,
|
||||
"LoadImageMask": LoadImageMask,
|
||||
"ImageScale": ImageScale,
|
||||
"ConditioningCombine": ConditioningCombine,
|
||||
"ConditioningSetArea": ConditioningSetArea,
|
||||
"KSamplerAdvanced": KSamplerAdvanced,
|
||||
"SetLatentNoiseMask": SetLatentNoiseMask,
|
||||
"LatentComposite": LatentComposite,
|
||||
"LatentRotate": LatentRotate,
|
||||
"LatentFlip": LatentFlip,
|
||||
"LatentCrop": LatentCrop,
|
||||
"LoraLoader": LoraLoader,
|
||||
"CLIPLoader": CLIPLoader,
|
||||
"ControlNetApply": ControlNetApply,
|
||||
"ControlNetLoader": ControlNetLoader,
|
||||
}
|
||||
|
||||
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
|
||||
|
Loading…
Reference in New Issue
Block a user