From cef2cc3cb0c2454d89567c96929ad608df033cca Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 16:38:20 -0500 Subject: [PATCH] Support for inpaint models. --- comfy/samplers.py | 88 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 15 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index b806381e..c8ab106b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -21,8 +21,8 @@ class CFGDenoiser(torch.nn.Module): uncond = self.inner_model(x, sigma, cond=uncond) return uncond + (cond - uncond) * cond_scale -def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): - def get_area_and_mult(cond, x_in): +def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None): + def get_area_and_mult(cond, x_in, cond_concat_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 min_sigma = 0.0 @@ -48,9 +48,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): if (area[1] + area[3]) < x_in.shape[3]: for t in range(rr): mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1)) - return (input_x, mult, cond[0], area) + conditionning = {} + conditionning['c_crossattn'] = cond[0] + if cond_concat_in is not None and len(cond_concat_in) > 0: + cropped = [] + for x in cond_concat_in: + cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + cropped.append(cr) + conditionning['c_concat'] = torch.cat(cropped, dim=1) + return (input_x, mult, conditionning, area) - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area): + def cond_equal_size(c1, c2): + if c1.keys() != c2.keys(): + return False + if 'c_crossattn' in c1: + if c1['c_crossattn'].shape != c2['c_crossattn'].shape: + return False + if 'c_concat' in c1: + if c1['c_concat'].shape != c2['c_concat'].shape: + return False + return True + + def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + for x in c_list: + if 'c_crossattn' in x: + c_crossattn.append(x['c_crossattn']) + if 'c_concat' in x: + c_concat.append(x['c_concat']) + out = {} + if len(c_crossattn) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn)] + if len(c_concat) > 0: + out['c_concat'] = [torch.cat(c_concat)] + return out + + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -62,13 +96,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): to_run = [] for x in cond: - p = get_area_and_mult(x, x_in) + p = get_area_and_mult(x, x_in, cond_concat_in) if p is None: continue to_run += [(p, COND)] for x in uncond: - p = get_area_and_mult(x, x_in) + p = get_area_and_mult(x, x_in, cond_concat_in) if p is None: continue @@ -80,7 +114,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): to_batch_temp = [] for x in range(len(to_run)): if to_run[x][0][0].shape == first_shape: - if to_run[x][0][2].shape == first[0][2].shape: + if cond_equal_size(to_run[x][0][2], first[0][2]): to_batch_temp += [x] to_batch_temp.reverse() @@ -108,7 +142,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) - c = torch.cat(c) + c = cond_cat(c) sigma_ = torch.cat([sigma] * batch_chunks) output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) @@ -132,18 +166,18 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat) return uncond + (cond - uncond) * cond_scale class CFGDenoiserComplex(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): if denoise_mask is not None: latent_mask = 1. - denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask - out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) + out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat) if denoise_mask is not None: out *= denoise_mask @@ -159,6 +193,17 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) +def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + # the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE + # unfortunately that gives zero flexibility so I did things like this instead which hopefully works + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + def create_cond_with_same_area_if_none(conds, c): if 'area' not in c[1]: return @@ -276,11 +321,24 @@ class KSampler: else: precision_scope = contextlib.nullcontext - latent_mask = None - if denoise_mask is not None: - latent_mask = (torch.ones_like(denoise_mask) - denoise_mask) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} + + if hasattr(self.model, 'concat_keys'): + cond_concat = [] + for ck in self.model.concat_keys: + if denoise_mask is not None: + if ck == "mask": + cond_concat.append(denoise_mask[:,:1]) + elif ck == "masked_image": + blank_image = blank_inpaint_image_like(latent_image) + cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image) + else: + if ck == "mask": + cond_concat.append(torch.ones_like(noise)[:,:1]) + elif ck == "masked_image": + cond_concat.append(blank_inpaint_image_like(noise)) + extra_args["cond_concat"] = cond_concat + with precision_scope(self.device): if self.sampler == "uni_pc": samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)