mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support for inpaint models.
This commit is contained in:
parent
07db00355f
commit
cef2cc3cb0
@ -21,8 +21,8 @@ class CFGDenoiser(torch.nn.Module):
|
||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
def get_area_and_mult(cond, x_in):
|
||||
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
|
||||
def get_area_and_mult(cond, x_in, cond_concat_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
min_sigma = 0.0
|
||||
@ -48,9 +48,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
|
||||
return (input_x, mult, cond[0], area)
|
||||
conditionning = {}
|
||||
conditionning['c_crossattn'] = cond[0]
|
||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||
cropped = []
|
||||
for x in cond_concat_in:
|
||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
cropped.append(cr)
|
||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||
return (input_x, mult, conditionning, area)
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area):
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
if 'c_crossattn' in c1:
|
||||
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
|
||||
return False
|
||||
if 'c_concat' in c1:
|
||||
if c1['c_concat'].shape != c2['c_concat'].shape:
|
||||
return False
|
||||
return True
|
||||
|
||||
def cond_cat(c_list):
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
for x in c_list:
|
||||
if 'c_crossattn' in x:
|
||||
c_crossattn.append(x['c_crossattn'])
|
||||
if 'c_concat' in x:
|
||||
c_concat.append(x['c_concat'])
|
||||
out = {}
|
||||
if len(c_crossattn) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn)]
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = [torch.cat(c_concat)]
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in)/100000.0
|
||||
|
||||
@ -62,13 +96,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in)
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in)
|
||||
p = get_area_and_mult(x, x_in, cond_concat_in)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
@ -80,7 +114,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if to_run[x][0][0].shape == first_shape:
|
||||
if to_run[x][0][2].shape == first[0][2].shape:
|
||||
if cond_equal_size(to_run[x][0][2], first[0][2]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
@ -108,7 +142,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = torch.cat(c)
|
||||
c = cond_cat(c)
|
||||
sigma_ = torch.cat([sigma] * batch_chunks)
|
||||
|
||||
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
|
||||
@ -132,18 +166,18 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
|
||||
|
||||
|
||||
max_total_area = model_management.maximum_batch_area()
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area)
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
class CFGDenoiserComplex(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
|
||||
if denoise_mask is not None:
|
||||
latent_mask = 1. - denoise_mask
|
||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
|
||||
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
|
||||
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
|
||||
if denoise_mask is not None:
|
||||
out *= denoise_mask
|
||||
|
||||
@ -159,6 +193,17 @@ def simple_scheduler(model, steps):
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def blank_inpaint_image_like(latent_image):
|
||||
blank_image = torch.ones_like(latent_image)
|
||||
# these are the values for "zero" in pixel space translated to latent space
|
||||
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
|
||||
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works
|
||||
blank_image[:,0] *= 0.8223
|
||||
blank_image[:,1] *= -0.6876
|
||||
blank_image[:,2] *= 0.6364
|
||||
blank_image[:,3] *= 0.1380
|
||||
return blank_image
|
||||
|
||||
def create_cond_with_same_area_if_none(conds, c):
|
||||
if 'area' not in c[1]:
|
||||
return
|
||||
@ -276,11 +321,24 @@ class KSampler:
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
latent_mask = None
|
||||
if denoise_mask is not None:
|
||||
latent_mask = (torch.ones_like(denoise_mask) - denoise_mask)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
||||
|
||||
if hasattr(self.model, 'concat_keys'):
|
||||
cond_concat = []
|
||||
for ck in self.model.concat_keys:
|
||||
if denoise_mask is not None:
|
||||
if ck == "mask":
|
||||
cond_concat.append(denoise_mask[:,:1])
|
||||
elif ck == "masked_image":
|
||||
blank_image = blank_inpaint_image_like(latent_image)
|
||||
cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image)
|
||||
else:
|
||||
if ck == "mask":
|
||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||
elif ck == "masked_image":
|
||||
cond_concat.append(blank_inpaint_image_like(noise))
|
||||
extra_args["cond_concat"] = cond_concat
|
||||
|
||||
with precision_scope(self.device):
|
||||
if self.sampler == "uni_pc":
|
||||
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
|
||||
|
Loading…
Reference in New Issue
Block a user