Support for inpaint models.

This commit is contained in:
comfyanonymous 2023-02-15 16:38:20 -05:00
parent 07db00355f
commit cef2cc3cb0

View File

@ -21,8 +21,8 @@ class CFGDenoiser(torch.nn.Module):
uncond = self.inner_model(x, sigma, cond=uncond) uncond = self.inner_model(x, sigma, cond=uncond)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
def get_area_and_mult(cond, x_in): def get_area_and_mult(cond, x_in, cond_concat_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
min_sigma = 0.0 min_sigma = 0.0
@ -48,9 +48,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
if (area[1] + area[3]) < x_in.shape[3]: if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr): for t in range(rr):
mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1)) mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
return (input_x, mult, cond[0], area) conditionning = {}
conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
return (input_x, mult, conditionning, area)
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area): def cond_equal_size(c1, c2):
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
return True
def cond_cat(c_list):
c_crossattn = []
c_concat = []
for x in c_list:
if 'c_crossattn' in x:
c_crossattn.append(x['c_crossattn'])
if 'c_concat' in x:
c_concat.append(x['c_concat'])
out = {}
if len(c_crossattn) > 0:
out['c_crossattn'] = [torch.cat(c_crossattn)]
if len(c_concat) > 0:
out['c_concat'] = [torch.cat(c_concat)]
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0 out_count = torch.ones_like(x_in)/100000.0
@ -62,13 +96,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
to_run = [] to_run = []
for x in cond: for x in cond:
p = get_area_and_mult(x, x_in) p = get_area_and_mult(x, x_in, cond_concat_in)
if p is None: if p is None:
continue continue
to_run += [(p, COND)] to_run += [(p, COND)]
for x in uncond: for x in uncond:
p = get_area_and_mult(x, x_in) p = get_area_and_mult(x, x_in, cond_concat_in)
if p is None: if p is None:
continue continue
@ -80,7 +114,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
to_batch_temp = [] to_batch_temp = []
for x in range(len(to_run)): for x in range(len(to_run)):
if to_run[x][0][0].shape == first_shape: if to_run[x][0][0].shape == first_shape:
if to_run[x][0][2].shape == first[0][2].shape: if cond_equal_size(to_run[x][0][2], first[0][2]):
to_batch_temp += [x] to_batch_temp += [x]
to_batch_temp.reverse() to_batch_temp.reverse()
@ -108,7 +142,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
batch_chunks = len(cond_or_uncond) batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x) input_x = torch.cat(input_x)
c = torch.cat(c) c = cond_cat(c)
sigma_ = torch.cat([sigma] * batch_chunks) sigma_ = torch.cat([sigma] * batch_chunks)
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
@ -132,18 +166,18 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
max_total_area = model_management.maximum_batch_area() max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area) cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
class CFGDenoiserComplex(torch.nn.Module): class CFGDenoiserComplex(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask): def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
if denoise_mask is not None: if denoise_mask is not None:
latent_mask = 1. - denoise_mask latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
if denoise_mask is not None: if denoise_mask is not None:
out *= denoise_mask out *= denoise_mask
@ -159,6 +193,17 @@ def simple_scheduler(model, steps):
sigs += [0.0] sigs += [0.0]
return torch.FloatTensor(sigs) return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
def create_cond_with_same_area_if_none(conds, c): def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]: if 'area' not in c[1]:
return return
@ -276,11 +321,24 @@ class KSampler:
else: else:
precision_scope = contextlib.nullcontext precision_scope = contextlib.nullcontext
latent_mask = None
if denoise_mask is not None:
latent_mask = (torch.ones_like(denoise_mask) - denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
if hasattr(self.model, 'concat_keys'):
cond_concat = []
for ck in self.model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
blank_image = blank_inpaint_image_like(latent_image)
cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image)
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
with precision_scope(self.device): with precision_scope(self.device):
if self.sampler == "uni_pc": if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask) samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)