Merge branch 'comfyanonymous:master' into master

Fannovel16 2023-02-17 18:02:52 +07:00 committed by GitHub
commit fa66ece26b
12 changed files with 981 additions and 94 deletions

comfy/cldm/cldm.py (new file, 286 lines)

@@ -0,0 +1,286 @@
#taken from: https://github.com/lllyasviel/ControlNet
#and modified
import einops
import torch
import torch as th
import torch.nn as nn
from ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config
class ControlledUnetModel(UNetModel):
#implemented in the ldm unet
pass
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
self.input_hint_block = TimestepEmbedSequential(
conv_nd(dims, hint_channels, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
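The zero_module / make_zero_conv pattern above is what lets a freshly initialized ControlNet leave the base model untouched: every residual it feeds into the UNet starts out exactly zero, so training can only gradually steer the frozen model. A minimal self-contained sketch of the idea (zero_module is reimplemented here just for illustration):

import torch
import torch.nn as nn

def zero_module(module):
    # zero every parameter so the module's initial output is exactly zero
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

zero_conv = zero_module(nn.Conv2d(320, 320, kernel_size=1, padding=0))
h = torch.randn(1, 320, 64, 64)
print(zero_conv(h).abs().max())  # prints 0: the control residual starts as a no-op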

comfy/extra_samplers/uni_pc.py

@@ -358,7 +358,10 @@ class UniPC:
predict_x0=True,
thresholding=False,
max_val=1.,
variant='bh1'
variant='bh1',
noise_mask=None,
masked_image=None,
noise=None,
):
"""Construct a UniPC.
@@ -370,7 +373,10 @@ class UniPC:
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
self.noise_mask = noise_mask
self.masked_image = masked_image
self.noise = noise
def dynamic_thresholding_fn(self, x0, t=None):
"""
The dynamic thresholding method.
@@ -386,7 +392,10 @@ class UniPC:
"""
Return the noise prediction model.
"""
return self.model(x, t)
if self.noise_mask is not None:
return self.model(x, t) * self.noise_mask
else:
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
@@ -401,6 +410,8 @@ class UniPC:
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
if self.noise_mask is not None:
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
return x0
def model_fn(self, x, t):
@@ -713,6 +724,8 @@ class UniPC:
assert timesteps.shape[0] - 1 == steps
# with torch.no_grad():
for step_index in trange(steps):
if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
@@ -820,7 +833,7 @@ def expand_dims(v, dims):
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
to_zero = False
if sigmas[-1] == 0:
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
@@ -843,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
device = noise.device
if model.inner_model.parameterization == "v":
if model.parameterization == "v":
model_type = "v"
else:
model_type = "noise"
model_fn = model_wrapper(
model.inner_model.apply_model,
model.inner_model.inner_model.apply_model,
sampling_function,
ns,
model_type=model_type,
@@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
model_kwargs=extra_args,
)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
if not to_zero:
x /= ns.marginal_alpha(timesteps[-1])
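The noise_mask plumbing added above implements latent inpainting for UniPC: before every solver step, the region outside the mask is pinned to the masked image re-noised to the current timestep, and only the masked region is left for the sampler to change. A standalone paraphrase of that per-step blend (the helper name is mine, not from the diff):

import torch

def pin_unmasked_region(x, noise_mask, masked_image, noise, alpha_t, sigma_t):
    # noise_mask == 1 where the sampler may write; elsewhere keep the
    # masked image diffused to the current timestep
    return x * noise_mask + (1. - noise_mask) * (masked_image * alpha_t + noise * sigma_t)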

comfy/ldm/models/diffusion/ddpm.py

@@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
out = self.diffusion_model(x, t, control=control)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
out = self.diffusion_model(xc, t, control=control)
elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1)
@@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
# TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc)
out = self.scripted_diffusion_model(x, t, cc, control=control)
else:
out = self.diffusion_model(x, t, context=cc)
out = self.diffusion_model(x, t, context=cc, control=control)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
out = self.diffusion_model(xc, t, context=cc, control=control)
elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm)
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
out = self.diffusion_model(x, t, y=cc, control=control)
else:
raise NotImplementedError()

comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@@ -778,8 +778,14 @@ class UNetModel(nn.Module):
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
hsp = hs.pop()
if control is not None:
hsp += control.pop()
h = th.cat([h, hsp], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
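Because the control residuals are consumed with pop(), the list returned by the ControlNet is read back to front: the middle-block residual first, then the input-block residuals from deepest to shallowest, mirroring the order in which the hs skip connections are popped. A toy check of that ordering (the block count is illustrative, matching an SD1.x UNet):

control = ["input_block_%d" % i for i in range(12)] + ["middle_block"]
assert control.pop() == "middle_block"    # consumed right after self.middle_block
assert control.pop() == "input_block_11"  # deepest skip connection gets its residual first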

comfy/model_management.py

@@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
current_loaded_model = None
current_gpu_controlnets = []
model_accelerated = False
@@ -56,6 +56,7 @@ model_accelerated = False
def unload_model():
global current_loaded_model
global model_accelerated
global current_gpu_controlnets
if current_loaded_model is not None:
if model_accelerated:
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
@@ -64,6 +65,10 @@ def unload_model():
current_loaded_model.model.cpu()
current_loaded_model.unpatch_model()
current_loaded_model = None
if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets:
n.cpu()
current_gpu_controlnets = []
def load_model_gpu(model):
@@ -95,6 +100,16 @@ def load_model_gpu(model):
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(models):
global current_gpu_controlnets
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
current_gpu_controlnets = []
for m in models:
current_gpu_controlnets.append(m.cuda())
def get_free_memory():
dev = torch.cuda.current_device()

comfy/samplers.py

@@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
uncond = self.inner_model(x, sigma, cond=uncond)
return uncond + (cond - uncond) * cond_scale
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
def get_area_and_mult(cond, x_in):
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
min_sigma = 0.0
max_sigma = 999.0
if 'area' in cond[1]:
area = cond[1]['area']
if 'strength' in cond[1]:
@@ -48,9 +49,60 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
return (input_x, mult, cond[0], area)
conditioning = {}
conditioning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditioning['c_concat'] = torch.cat(cropped, dim=1)
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area):
control = None
if 'control' in cond[1]:
control = cond[1]['control']
return (input_x, mult, conditioning, area, control)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
return True
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
if 'c_concat' in x:
c_concat.append(x['c_concat'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
return out
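can_concat_cond and cond_cat exist so that cond and uncond entries with matching tensor shapes (and the same ControlNet object, if any) can share a single batched model call instead of two. A small sketch of the merge, assuming cond_cat were lifted to module scope:

import torch

c1 = {'c_crossattn': torch.randn(1, 77, 768)}
c2 = {'c_crossattn': torch.randn(1, 77, 768)}
merged = cond_cat([c1, c2])
print(merged['c_crossattn'][0].shape)  # torch.Size([2, 77, 768]): one batched forward pass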
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
@@ -62,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in)
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
for x in uncond:
p = get_area_and_mult(x, x_in)
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
@@ -79,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if to_run[x][0][0].shape == first_shape:
if to_run[x][0][2].shape == first[0][2].shape:
to_batch_temp += [x]
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
@@ -97,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
c = []
cond_or_uncond = []
area = []
control = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
@@ -105,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
c += [p[2]]
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = torch.cat(c)
sigma_ = torch.cat([sigma] * batch_chunks)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
@@ -132,15 +188,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area)
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
return uncond + (cond - uncond) * cond_scale
class CFGDenoiserComplex(torch.nn.Module):
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond, **kwargs)
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
return out
class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
if denoise_mask is not None:
out *= denoise_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
sigs = []
@@ -150,6 +234,15 @@ def simple_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
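Per the comment above, the four constants are the per-channel latent values corresponding to a "zero" pixel image, so masked-out regions start from a neutral latent rather than from garbage. A quick check of the helper's output (assuming the function above is in scope):

import torch

blank = blank_inpaint_image_like(torch.zeros(1, 4, 64, 64))
print(blank[0, :, 0, 0])  # tensor([ 0.8223, -0.6876,  0.6364,  0.1380])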
def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]:
return
@@ -180,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy()
conds += [[smallest[0], n]]
def apply_control_net_to_equal_area(conds, uncond):
cond_cnets = []
cond_other = []
uncond_cnets = []
uncond_other = []
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
cond_cnets.append(x[1]['control'])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
uncond_cnets.append(x[1]['control'])
else:
uncond_other.append((x, t))
if len(uncond_cnets) > 0:
return
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if 'control' in o[1] and o[1]['control'] is not None:
n = o[1].copy()
n['control'] = cond_cnets[x]
uncond += [[o[0], n]]
else:
n = o[1].copy()
n['control'] = cond_cnets[x]
uncond[temp[1]] = [o[0], n]
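apply_control_net_to_equal_area copies the positive prompt's ControlNets onto the negative conditioning whenever the negative has none of its own, so both CFG branches see the same control hint. A hedged sketch of the effect (assumes the function and some ControlNet instance cn are in scope):

import torch

cond = [[torch.randn(1, 77, 768), {'control': cn}]]  # positive carries a ControlNet
uncond = [[torch.randn(1, 77, 768), {}]]             # negative does not
apply_control_net_to_equal_area(cond, uncond)
assert uncond[0][1]['control'] is cn                 # both CFG branches now share the hint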
class KSampler:
SCHEDULERS = ["karras", "normal", "simple"]
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
@@ -188,11 +317,13 @@ class KSampler:
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v":
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
self.model_k = CFGDenoiserComplex(self.model_wrap)
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
scheduler = self.SCHEDULERS[0]
@@ -200,8 +331,8 @@
sampler = self.SAMPLERS[0]
self.scheduler = scheduler
self.sampler = sampler
self.sigma_min=float(self.model_wrap.sigmas[0])
self.sigma_max=float(self.model_wrap.sigmas[-1])
self.sigma_min=float(self.model_wrap.sigma_min)
self.sigma_max=float(self.model_wrap.sigma_max)
self.set_steps(steps, denoise)
def _calculate_sigmas(self, steps):
@@ -235,7 +366,7 @@
self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False):
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
sigmas = self.sigmas
sigma_min = self.sigma_min
@@ -262,22 +393,47 @@
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_control_net_to_equal_area(positive, negative)
if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
if hasattr(self.model, 'concat_keys'):
cond_concat = []
for ck in self.model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
with precision_scope(self.device):
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
else:
noise *= sigmas[0]
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise
noise = noise * sigmas[0]
if latent_image is not None:
noise += latent_image
if self.sampler == "sample_dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
elif self.sampler == "sample_dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
else:
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args)
return samples.to(torch.float32)
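The concat_keys branch above feeds dedicated inpainting checkpoints, whose UNet expects the noisy latent concatenated with a mask channel and the masked-image latent. A hedged shape check for such a model (in_channels: 9, matching the v2 inpainting config added in this commit):

import torch

x = torch.randn(2, 4, 64, 64)             # noisy latent
mask = torch.ones(2, 1, 64, 64)           # "mask" concat key
masked_image = torch.zeros(2, 4, 64, 64)  # "masked_image" concat key (latent space)
print(torch.cat([x, mask, masked_image], dim=1).shape)  # torch.Size([2, 9, 64, 64])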

comfy/sd.py

@ -6,6 +6,9 @@ import model_management
from ldm.util import instantiate_from_config
from ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf
from .cldm import cldm
from . import utils
def load_torch_file(ckpt):
if ckpt.lower().endswith(".safetensors"):
@@ -323,6 +326,84 @@ class VAE:
samples = samples.cpu()
return samples
class ControlNet:
def __init__(self, control_model):
self.control_model = control_model
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
def get_control(self, x_noisy, t, cond_txt):
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
print("set cond_hint", self.cond_hint.shape)
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
for x in control:
x *= self.strength
return control
def set_cond_hint(self, cond_hint, strength=1.0):
self.cond_hint_original = cond_hint
self.strength = strength
return self
def cleanup(self):
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
def copy(self):
c = ControlNet(self.control_model)
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
return c
def load_controlnet(ckpt_path):
controlnet_data = load_torch_file(ckpt_path)
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
pth = False
sd2 = False
key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
if pth_key in controlnet_data:
pth = True
key = pth_key
elif key in controlnet_data:
pass
else:
print("error checkpoint does not contain controlnet data", ckpt_path)
return None
context_dim = controlnet_data[key].shape[1]
control_model = cldm.ControlNet(image_size=32,
in_channels=4,
hint_channels=3,
model_channels=320,
attention_resolutions=[ 4, 2, 1 ],
num_res_blocks=2,
channel_mult=[ 1, 2, 4, 4 ],
num_heads=8,
use_spatial_transformer=True,
transformer_depth=1,
context_dim=context_dim,
use_checkpoint=True,
legacy=False)
if pth:
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.control_model = control_model
w.load_state_dict(controlnet_data, strict=False)
else:
control_model.load_state_dict(controlnet_data, strict=False)
control = ControlNet(control_model)
return control
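The WeightsLoader shim exists because .pth ControlNet checkpoints prefix every key with control_model.; hanging the model off an attribute of that exact name makes the prefixes line up for load_state_dict. A self-contained illustration with a toy module in place of the real ControlNet:

import torch

class WeightsLoader(torch.nn.Module):
    pass

inner = torch.nn.Linear(2, 2)
w = WeightsLoader()
w.control_model = inner  # keys now resolve as "control_model.weight", "control_model.bias"
state = {'control_model.weight': torch.zeros(2, 2), 'control_model.bias': torch.zeros(2)}
w.load_state_dict(state, strict=False)
print(float(inner.weight.abs().sum()))  # 0.0: the weights landed on the wrapped module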
def load_clip(ckpt_path, embedding_directory=None):
clip_data = load_torch_file(ckpt_path)
config = {}

comfy/utils.py (new file, 18 lines)

@@ -0,0 +1,18 @@
import torch
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
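A quick usage check of the helper above (the sizes are arbitrary): with crop="center", a wider-than-target latent is first cropped to the target aspect ratio, then interpolated.

import torch

latent = torch.randn(1, 4, 64, 96)  # H=64, W=96: wider than the square target
out = common_upscale(latent, 128, 128, "bilinear", "center")
print(out.shape)  # torch.Size([1, 4, 128, 128])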

main.py

@@ -7,6 +7,10 @@ import heapq
import traceback
import asyncio
if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
try:
import aiohttp
from aiohttp import web

models/configs/v2-inpainting-inference.yaml

@@ -0,0 +1,158 @@
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
data:
target: ldm.data.laion.WebDataModuleFromConfig
params:
tar_base: null # for concat as in LAION-A
p_unsafe_threshold: 0.1
filter_word_list: "data/filters.yaml"
max_pwatermark: 0.45
batch_size: 8
num_workers: 6
multinode: True
min_size: 512
train:
shards:
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
- "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
shuffle: 10000
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.RandomCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
# NOTE use enough shards to avoid empty validation loops in workers
validation:
shards:
- "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
shuffle: 0
image_key: jpg
image_transforms:
- target: torchvision.transforms.Resize
params:
size: 512
interpolation: 3
- target: torchvision.transforms.CenterCrop
params:
size: 512
postprocess:
target: ldm.data.laion.AddMask
params:
mode: "512train-large"
p_drop: 0.25
lightning:
find_unused_parameters: True
modelcheckpoint:
params:
every_n_train_steps: 5000
callbacks:
metrics_over_trainsteps_checkpoint:
params:
every_n_train_steps: 10000
image_logger:
target: main.ImageLogger
params:
enable_autocast: False
disabled: False
batch_frequency: 1000
max_images: 4
increase_log_steps: False
log_first_step: False
log_images_kwargs:
use_ema_scope: False
inpaint: False
plot_progressive_rows: False
plot_diffusion_rows: False
N: 4
unconditional_guidance_scale: 5.0
unconditional_guidance_label: [""]
ddim_steps: 50 # todo check these out for depth2img,
ddim_eta: 0.0 # todo check these out for depth2img,
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1


nodes.py (242 lines changed)

@@ -15,11 +15,13 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
import comfy.samplers
import comfy.sd
import comfy.utils
import model_management
import importlib
supported_ckpt_extensions = ['.ckpt']
supported_pt_extensions = ['.ckpt', '.pt', '.bin']
supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
try:
import safetensors.torch
supported_ckpt_extensions += ['.safetensors']
@@ -78,12 +80,14 @@ class ConditioningSetArea:
CATEGORY = "conditioning"
def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
c = copy.deepcopy(conditioning)
for t in c:
t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
t[1]['strength'] = strength
t[1]['min_sigma'] = min_sigma
t[1]['max_sigma'] = max_sigma
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
n[1]['strength'] = strength
n[1]['min_sigma'] = min_sigma
n[1]['max_sigma'] = max_sigma
c.append(n)
return (c, )
class VAEDecode:
@@ -99,7 +103,7 @@ class VAEDecode:
CATEGORY = "latent"
def decode(self, vae, samples):
return (vae.decode(samples), )
return (vae.decode(samples["samples"]), )
class VAEEncode:
def __init__(self, device="cpu"):
@@ -118,7 +122,39 @@ class VAEEncode:
y = (pixels.shape[2] // 64) * 64
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
return (vae.encode(pixels), )
t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, )
class VAEEncodeForInpaint:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask):
x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
mask = mask[:x,:y]
#shave off a few pixels to keep things seamless
kernel_tensor = torch.ones((1, 1, 6, 6))
mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
for i in range(3):
pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
pixels[:,:,:,i] += 0.5
t = vae.encode(pixels)
return ({"samples":t, "noise_mask": mask}, )
class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
@@ -178,6 +214,48 @@ class VAELoader:
vae = comfy.sd.VAE(ckpt_path=vae_path)
return (vae,)
class ControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
CATEGORY = "loaders"
def load_controlnet(self, control_net_name):
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
controlnet = comfy.sd.load_controlnet(controlnet_path)
return (controlnet,)
class ControlNetApply:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"control_net": ("CONTROL_NET", ),
"image": ("IMAGE", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning"
def apply_controlnet(self, conditioning, control_net, image, strength):
c = []
control_hint = image.movedim(-1,1)
print(control_hint.shape)
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength)
c.append(n)
return (c, )
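Note the movedim above: ComfyUI IMAGE tensors are channel-last floats in [0, 1] ([B, H, W, C], as produced by LoadImage below), while the ControlNet hint must be channel-first. A one-line sanity check:

import torch

image = torch.rand(1, 512, 512, 3)   # IMAGE layout: [B, H, W, C]
control_hint = image.movedim(-1, 1)  # -> [B, C, H, W], what the ControlNet expects
print(control_hint.shape)            # torch.Size([1, 3, 512, 512])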
class CLIPLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
clip_dir = os.path.join(models_dir, "clip")
@@ -213,24 +291,9 @@ class EmptyLatentImage:
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (latent, )
return ({"samples":latent}, )
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
old_width = samples.shape[3]
old_height = samples.shape[2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples[:,:,y:old_height-y,x:old_width-x]
else:
s = samples
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
class LatentUpscale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
@@ -248,7 +311,8 @@ class LatentUpscale:
CATEGORY = "latent"
def upscale(self, samples, upscale_method, width, height, crop):
s = common_upscale(samples, width // 8, height // 8, upscale_method, crop)
s = samples.copy()
s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
return (s,)
class LatentRotate:
@@ -263,6 +327,7 @@ class LatentRotate:
CATEGORY = "latent"
def rotate(self, samples, rotation):
s = samples.copy()
rotate_by = 0
if rotation.startswith("90"):
rotate_by = 1
@@ -271,7 +336,7 @@ class LatentRotate:
elif rotation.startswith("270"):
rotate_by = 3
s = torch.rot90(samples, k=rotate_by, dims=[3, 2])
s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
return (s,)
class LatentFlip:
@@ -286,12 +351,11 @@ class LatentFlip:
CATEGORY = "latent"
def flip(self, samples, flip_method):
s = samples.copy()
if flip_method.startswith("x"):
s = torch.flip(samples, dims=[2])
s["samples"] = torch.flip(samples["samples"], dims=[2])
elif flip_method.startswith("y"):
s = torch.flip(samples, dims=[3])
else:
s = samples
s["samples"] = torch.flip(samples["samples"], dims=[3])
return (s,)
@@ -313,12 +377,15 @@ class LatentComposite:
x = x // 8
y = y // 8
feather = feather // 8
s = samples_to.clone()
samples_out = samples_to.copy()
s = samples_to["samples"].clone()
samples_to = samples_to["samples"]
samples_from = samples_from["samples"]
if feather == 0:
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
else:
s_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
mask = torch.ones_like(s_from)
samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
mask = torch.ones_like(samples_from)
for t in range(feather):
if y != 0:
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
@@ -331,7 +398,8 @@ class LatentComposite:
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
rev_mask = torch.ones_like(mask) - mask
s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
return (s,)
samples_out["samples"] = s
return (samples_out,)
class LatentCrop:
@classmethod
@@ -348,6 +416,8 @@ class LatentCrop:
CATEGORY = "latent"
def crop(self, samples, width, height, x, y):
s = samples.copy()
samples = samples['samples']
x = x // 8
y = y // 8
@@ -371,15 +441,43 @@ class LatentCrop:
#make sure size is always multiple of 64
x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
s = samples[:,:,y:to_y, x:to_x]
s['samples'] = samples[:,:,y:to_y, x:to_x]
return (s,)
def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
class SetLatentNoiseMask:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"mask": ("MASK",),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "set_mask"
CATEGORY = "latent/inpaint"
def set_mask(self, samples, mask):
s = samples.copy()
s["noise_mask"] = mask
return (s,)
def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
noise_mask = None
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
if "noise_mask" in latent:
noise_mask = latent['noise_mask']
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0])
noise_mask = noise_mask.to(device)
real_model = None
if device != "cpu":
model_management.load_model_gpu(model)
@@ -393,29 +491,40 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
positive_copy = []
negative_copy = []
control_nets = []
for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
if 'control' in p[1]:
control_nets += [p[1]['control']]
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
if 'control' in n[1]:
control_nets += [n[1]['control']]
negative_copy += [[t] + n[1:]]
model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
else:
#other samplers
pass
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu()
for c in control_nets:
c.cleanup()
return (samples, )
out = latent.copy()
out["samples"] = samples
return (out, )
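One change that runs through all the node edits here: LATENT values are now dicts rather than bare tensors, with the tensor under "samples" and optional extras such as "noise_mask" alongside it. A minimal example of the new shape of the data:

import torch

latent = {"samples": torch.zeros(1, 4, 64, 64)}         # what EmptyLatentImage now returns
latent = dict(latent, noise_mask=torch.ones(512, 512))  # what SetLatentNoiseMask adds; resized to the latent size inside common_ksampler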
class KSampler:
def __init__(self, device="cuda"):
@@ -532,7 +641,7 @@ class LoadImage:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"image": (os.listdir(s.input_dir), )},
{"image": (sorted(os.listdir(s.input_dir)), )},
}
CATEGORY = "image"
@@ -541,10 +650,11 @@ class LoadImage:
FUNCTION = "load_image"
def load_image(self, image):
image_path = os.path.join(self.input_dir, image)
image = Image.open(image_path).convert("RGB")
i = Image.open(image_path)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image[None])[None,]
return image
image = torch.from_numpy(image)[None,]
return (image,)
@classmethod
def IS_CHANGED(s, image):
@@ -554,6 +664,41 @@ class LoadImage:
m.update(f.read())
return m.digest().hex()
class LoadImageMask:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
def INPUT_TYPES(s):
return {"required":
{"image": (os.listdir(s.input_dir), ),
"channel": (["alpha", "red", "green", "blue"], ),}
}
CATEGORY = "image"
RETURN_TYPES = ("MASK",)
FUNCTION = "load_image"
def load_image(self, image, channel):
image_path = os.path.join(self.input_dir, image)
i = Image.open(image_path)
mask = None
c = channel[0].upper()
if c in i.getbands():
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
mask = torch.from_numpy(mask)
if c == 'A':
mask = 1. - mask
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (mask,)
@classmethod
def IS_CHANGED(s, image, channel):
image_path = os.path.join(s.input_dir, image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"]
@@ -571,7 +716,7 @@ class ImageScale:
def upscale(self, image, upscale_method, width, height, crop):
samples = image.movedim(-1,1)
s = common_upscale(samples, width, height, upscale_method, crop)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
s = s.movedim(1,-1)
return (s,)
@@ -581,21 +726,26 @@ NODE_CLASS_MAPPINGS = {
"CLIPTextEncode": CLIPTextEncode,
"VAEDecode": VAEDecode,
"VAEEncode": VAEEncode,
"VAEEncodeForInpaint": VAEEncodeForInpaint,
"VAELoader": VAELoader,
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"SaveImage": SaveImage,
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"ImageScale": ImageScale,
"ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea,
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite,
"LatentRotate": LatentRotate,
"LatentFlip": LatentFlip,
"LatentCrop": LatentCrop,
"LoraLoader": LoraLoader,
"CLIPLoader": CLIPLoader,
"ControlNetApply": ControlNetApply,
"ControlNetLoader": ControlNetLoader,
}
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")