mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Refactor cond_concat into conditioning.
This commit is contained in:
parent
430a8334c5
commit
45c972aba8
@ -14,8 +14,8 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
|||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns predicted noise
|
#Returns predicted noise
|
||||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
|
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
def get_area_and_mult(cond, x_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
if 'timestep_start' in cond[1]:
|
if 'timestep_start' in cond[1]:
|
||||||
@ -68,12 +68,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
conditionning = {}
|
conditionning = {}
|
||||||
conditionning['c_crossattn'] = cond[0]
|
conditionning['c_crossattn'] = cond[0]
|
||||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
|
||||||
cropped = []
|
if 'concat' in cond[1]:
|
||||||
for x in cond_concat_in:
|
cond_concat_in = cond[1]['concat']
|
||||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||||
cropped.append(cr)
|
cropped = []
|
||||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
for x in cond_concat_in:
|
||||||
|
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
|
cropped.append(cr)
|
||||||
|
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||||
|
|
||||||
if adm_cond is not None:
|
if adm_cond is not None:
|
||||||
conditionning['c_adm'] = adm_cond
|
conditionning['c_adm'] = adm_cond
|
||||||
@ -173,7 +176,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
out['c_adm'] = torch.cat(c_adm)
|
out['c_adm'] = torch.cat(c_adm)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in)/100000.0
|
out_count = torch.ones_like(x_in)/100000.0
|
||||||
|
|
||||||
@ -185,14 +188,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
to_run = []
|
to_run = []
|
||||||
for x in cond:
|
for x in cond:
|
||||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
to_run += [(p, COND)]
|
||||||
if uncond is not None:
|
if uncond is not None:
|
||||||
for x in uncond:
|
for x in uncond:
|
||||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -286,7 +289,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if math.isclose(cond_scale, 1.0):
|
if math.isclose(cond_scale, 1.0):
|
||||||
uncond = None
|
uncond = None
|
||||||
|
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
||||||
return model_options["sampler_cfg_function"](args)
|
return model_options["sampler_cfg_function"](args)
|
||||||
@ -307,8 +310,8 @@ class CFGNoisePredictor(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.alphas_cumprod = model.alphas_cumprod
|
self.alphas_cumprod = model.alphas_cumprod
|
||||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
|
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
|
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@ -316,11 +319,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
|||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
||||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
|
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out *= denoise_mask
|
out *= denoise_mask
|
||||||
|
|
||||||
@ -534,6 +537,19 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
|||||||
|
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
def encode_cond(model_function, key, conds, **kwargs):
|
||||||
|
for t in range(len(conds)):
|
||||||
|
x = conds[t]
|
||||||
|
params = x[1].copy()
|
||||||
|
for k in kwargs:
|
||||||
|
if k not in params:
|
||||||
|
params[k] = kwargs[k]
|
||||||
|
|
||||||
|
out = model_function(**params)
|
||||||
|
if out is not None:
|
||||||
|
x[1] = x[1].copy()
|
||||||
|
x[1][key] = out
|
||||||
|
return conds
|
||||||
|
|
||||||
class Sampler:
|
class Sampler:
|
||||||
def sample(self):
|
def sample(self):
|
||||||
@ -653,20 +669,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
|
if latent_image is not None:
|
||||||
|
latent_image = model.process_latent_in(latent_image)
|
||||||
|
|
||||||
if model.is_adm():
|
if model.is_adm():
|
||||||
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||||
|
|
||||||
if latent_image is not None:
|
if hasattr(model, 'cond_concat'):
|
||||||
latent_image = model.process_latent_in(latent_image)
|
positive = encode_cond(model.cond_concat, "concat", positive, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
|
negative = encode_cond(model.cond_concat, "concat", negative, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||||
|
|
||||||
if hasattr(model, 'cond_concat'):
|
|
||||||
cond_concat = model.cond_concat(noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
|
||||||
if cond_concat is not None:
|
|
||||||
extra_args["cond_concat"] = cond_concat
|
|
||||||
|
|
||||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||||
return model.process_latent_out(samples.to(torch.float32))
|
return model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user