mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Support for inpaint models.
This commit is contained in:
parent
07db00355f
commit
cef2cc3cb0
@ -21,8 +21,8 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
uncond = self.inner_model(x, sigma, cond=uncond)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
|
||||||
def get_area_and_mult(cond, x_in):
|
def get_area_and_mult(cond, x_in, cond_concat_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
min_sigma = 0.0
|
min_sigma = 0.0
|
||||||
@ -48,9 +48,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|||||||
if (area[1] + area[3]) < x_in.shape[3]:
|
if (area[1] + area[3]) < x_in.shape[3]:
|
||||||
for t in range(rr):
|
for t in range(rr):
|
||||||
mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
|
mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
|
||||||
return (input_x, mult, cond[0], area)
|
conditionning = {}
|
||||||
|
conditionning['c_crossattn'] = cond[0]
|
||||||
|
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||||
|
cropped = []
|
||||||
|
for x in cond_concat_in:
|
||||||
|
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
|
cropped.append(cr)
|
||||||
|
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||||
|
return (input_x, mult, conditionning, area)
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area):
|
def cond_equal_size(c1, c2):
|
||||||
|
if c1.keys() != c2.keys():
|
||||||
|
return False
|
||||||
|
if 'c_crossattn' in c1:
|
||||||
|
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
||||||
|
return False
|
||||||
|
if 'c_concat' in c1:
|
||||||
|
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def cond_cat(c_list):
|
||||||
|
c_crossattn = []
|
||||||
|
c_concat = []
|
||||||
|
for x in c_list:
|
||||||
|
if 'c_crossattn' in x:
|
||||||
|
c_crossattn.append(x['c_crossattn'])
|
||||||
|
if 'c_concat' in x:
|
||||||
|
c_concat.append(x['c_concat'])
|
||||||
|
out = {}
|
||||||
|
if len(c_crossattn) > 0:
|
||||||
|
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
||||||
|
if len(c_concat) > 0:
|
||||||
|
out['c_concat'] = [torch.cat(c_concat)]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in)/100000.0
|
out_count = torch.ones_like(x_in)/100000.0
|
||||||
|
|
||||||
@ -62,13 +96,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|||||||
|
|
||||||
to_run = []
|
to_run = []
|
||||||
for x in cond:
|
for x in cond:
|
||||||
p = get_area_and_mult(x, x_in)
|
p = get_area_and_mult(x, x_in, cond_concat_in)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
to_run += [(p, COND)]
|
||||||
for x in uncond:
|
for x in uncond:
|
||||||
p = get_area_and_mult(x, x_in)
|
p = get_area_and_mult(x, x_in, cond_concat_in)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -80,7 +114,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|||||||
to_batch_temp = []
|
to_batch_temp = []
|
||||||
for x in range(len(to_run)):
|
for x in range(len(to_run)):
|
||||||
if to_run[x][0][0].shape == first_shape:
|
if to_run[x][0][0].shape == first_shape:
|
||||||
if to_run[x][0][2].shape == first[0][2].shape:
|
if cond_equal_size(to_run[x][0][2], first[0][2]):
|
||||||
to_batch_temp += [x]
|
to_batch_temp += [x]
|
||||||
|
|
||||||
to_batch_temp.reverse()
|
to_batch_temp.reverse()
|
||||||
@ -108,7 +142,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|||||||
|
|
||||||
batch_chunks = len(cond_or_uncond)
|
batch_chunks = len(cond_or_uncond)
|
||||||
input_x = torch.cat(input_x)
|
input_x = torch.cat(input_x)
|
||||||
c = torch.cat(c)
|
c = cond_cat(c)
|
||||||
sigma_ = torch.cat([sigma] * batch_chunks)
|
sigma_ = torch.cat([sigma] * batch_chunks)
|
||||||
|
|
||||||
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
|
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
|
||||||
@ -132,18 +166,18 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
|||||||
|
|
||||||
|
|
||||||
max_total_area = model_management.maximum_batch_area()
|
max_total_area = model_management.maximum_batch_area()
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
class CFGDenoiserComplex(torch.nn.Module):
|
class CFGDenoiserComplex(torch.nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask):
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
||||||
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
|
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out *= denoise_mask
|
out *= denoise_mask
|
||||||
|
|
||||||
@ -159,6 +193,17 @@ def simple_scheduler(model, steps):
|
|||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
|
def blank_inpaint_image_like(latent_image):
|
||||||
|
blank_image = torch.ones_like(latent_image)
|
||||||
|
# these are the values for "zero" in pixel space translated to latent space
|
||||||
|
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
|
||||||
|
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works
|
||||||
|
blank_image[:,0] *= 0.8223
|
||||||
|
blank_image[:,1] *= -0.6876
|
||||||
|
blank_image[:,2] *= 0.6364
|
||||||
|
blank_image[:,3] *= 0.1380
|
||||||
|
return blank_image
|
||||||
|
|
||||||
def create_cond_with_same_area_if_none(conds, c):
|
def create_cond_with_same_area_if_none(conds, c):
|
||||||
if 'area' not in c[1]:
|
if 'area' not in c[1]:
|
||||||
return
|
return
|
||||||
@ -276,11 +321,24 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
precision_scope = contextlib.nullcontext
|
precision_scope = contextlib.nullcontext
|
||||||
|
|
||||||
latent_mask = None
|
|
||||||
if denoise_mask is not None:
|
|
||||||
latent_mask = (torch.ones_like(denoise_mask) - denoise_mask)
|
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
||||||
|
|
||||||
|
if hasattr(self.model, 'concat_keys'):
|
||||||
|
cond_concat = []
|
||||||
|
for ck in self.model.concat_keys:
|
||||||
|
if denoise_mask is not None:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(denoise_mask[:,:1])
|
||||||
|
elif ck == "masked_image":
|
||||||
|
blank_image = blank_inpaint_image_like(latent_image)
|
||||||
|
cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image)
|
||||||
|
else:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||||
|
elif ck == "masked_image":
|
||||||
|
cond_concat.append(blank_inpaint_image_like(noise))
|
||||||
|
extra_args["cond_concat"] = cond_concat
|
||||||
|
|
||||||
with precision_scope(self.device):
|
with precision_scope(self.device):
|
||||||
if self.sampler == "uni_pc":
|
if self.sampler == "uni_pc":
|
||||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
|
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
|
||||||
|
Loading…
Reference in New Issue
Block a user