Refactor to make it easier to add custom conds to models.

Author: comfyanonymous
Date:   2023-10-24 23:31:12 -04:00
Parent: 3fce8881ca
Commit: 036f88c621

4 changed files with 170 additions and 173 deletions

comfy/conds.py (new file)

@@ -0,0 +1,64 @@
+import enum
+import torch
+import math
+import comfy.utils
+
+
+def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
+    return abs(a*b) // math.gcd(a, b)
+
+class CONDRegular:
+    def __init__(self, cond):
+        self.cond = cond
+
+    def _copy_with(self, cond):
+        return self.__class__(cond)
+
+    def process_cond(self, batch_size, device, **kwargs):
+        return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
+
+    def can_concat(self, other):
+        if self.cond.shape != other.cond.shape:
+            return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        for x in others:
+            conds.append(x.cond)
+        return torch.cat(conds)
+
+class CONDNoiseShape(CONDRegular):
+    def process_cond(self, batch_size, device, area, **kwargs):
+        data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+        return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
+
+
+class CONDCrossAttn(CONDRegular):
+    def can_concat(self, other):
+        s1 = self.cond.shape
+        s2 = other.cond.shape
+        if s1 != s2:
+            if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
+                return False
+
+            mult_min = lcm(s1[1], s2[1])
+            diff = mult_min // min(s1[1], s2[1])
+            if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
+                return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        crossattn_max_len = self.cond.shape[1]
+        for x in others:
+            c = x.cond
+            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+            conds.append(c)
+        out = []
+        for c in conds:
+            if c.shape[1] < crossattn_max_len:
+                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
+            out.append(c)
+        return torch.cat(out)
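For orientation (not part of this commit): the padding rule in CONDCrossAttn.can_concat/concat in action. A minimal sketch in Python, assuming it runs inside the ComfyUI tree at this commit; the tensor shapes are illustrative:

import torch
import comfy.conds

# 77 and 154 token conds: lcm(77, 154) == 154 and the repeat factor 154 // 77 == 2
# is within the limit of 4, so the two conds may be batched together.
a = comfy.conds.CONDCrossAttn(torch.randn(1, 77, 768))
b = comfy.conds.CONDCrossAttn(torch.randn(1, 154, 768))
assert a.can_concat(b)
out = a.concat([b])  # a.cond is repeated along dim 1 to length 154 before torch.cat
print(out.shape)     # torch.Size([2, 154, 768])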

comfy/model_base.py

@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
+import comfy.conds
import numpy as np
from enum import Enum
from . import utils
@@ -49,7 +50,7 @@ class BaseModel(torch.nn.Module):
        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

-    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
+    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs):
        if c_concat is not None:
            xc = torch.cat([x] + [c_concat], dim=1)
        else:
@@ -72,7 +73,8 @@ class BaseModel(torch.nn.Module):
    def encode_adm(self, **kwargs):
        return None

-    def cond_concat(self, **kwargs):
+    def extra_conds(self, **kwargs):
+        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
@@ -101,8 +103,12 @@ class BaseModel(torch.nn.Module):
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
-            return cond_concat
-        return None
+            data = torch.cat(cond_concat, dim=1)
+            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['c_adm'] = comfy.conds.CONDRegular(adm)
+        return out

    def load_model_weights(self, sd, unet_prefix=""):
        to_load = {}

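This extra_conds hook is the extension point the commit message refers to: any key a model returns here is batched by the sampler and delivered back to apply_model through its new **kwargs. A hypothetical sketch (the subclass and the "my_embed"/"c_my_embed" names are illustrative, not part of this commit):

import comfy.conds
import comfy.model_base

class MyModel(comfy.model_base.BaseModel):  # hypothetical subclass, for illustration only
    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        # a custom node would attach "my_embed" to the cond dict; encode_model_conds
        # forwards every cond dict key to this function as a keyword argument
        embed = kwargs.get("my_embed", None)
        if embed is not None:
            # the CONDRegular wrapper handles batching and concat, so the sampler
            # needs no changes to route this tensor into apply_model's **kwargs
            out["c_my_embed"] = comfy.conds.CONDRegular(embed)
        return out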
comfy/sample.py

@@ -1,6 +1,7 @@
import torch
import comfy.model_management
import comfy.samplers
+import comfy.conds
import comfy.utils
import math
import numpy as np
@@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
    noise_mask = noise_mask.to(device)
    return noise_mask

-def broadcast_cond(cond, batch, device):
-    """broadcasts conditioning to the batch size"""
-    copy = []
-    for p in cond:
-        t = comfy.utils.repeat_to_batch_size(p[0], batch)
-        t = t.to(device)
-        copy += [[t] + p[1:]]
-    return copy
-
def get_models_from_cond(cond, model_type):
    models = []
    for c in cond:
-        if model_type in c[1]:
-            models += [c[1][model_type]]
+        if model_type in c:
+            models += [c[model_type]]
    return models

+def convert_cond(cond):
+    out = []
+    for c in cond:
+        temp = c[1].copy()
+        model_conds = temp.get("model_conds", {})
+        if c[0] is not None:
+            model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0])
+        temp["model_conds"] = model_conds
+        out.append(temp)
+    return out
+
def get_additional_models(positive, negative, dtype):
    """loads additional models in positive and negative conditioning"""
    control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
@@ -72,6 +75,8 @@ def cleanup_additional_models(models):
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
    device = model.load_device
+    positive = convert_cond(positive)
+    negative = convert_cond(negative)
    if noise_mask is not None:
        noise_mask = prepare_mask(noise_mask, noise_shape, device)
@@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
    real_model = model.model

-    positive_copy = broadcast_cond(positive, noise_shape[0], device)
-    negative_copy = broadcast_cond(negative, noise_shape[0], device)
-    return real_model, positive_copy, negative_copy, noise_mask, models
+    return real_model, positive, negative, noise_mask, models
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):

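The net effect of convert_cond above is a format change: the old [cross_attn_tensor, options_dict] pairs become plain dicts whose tensor conds live under a "model_conds" key. A minimal sketch, assuming the ComfyUI tree at this commit; the values are illustrative:

import torch
import comfy.conds
from comfy.sample import convert_cond

cross_attn = torch.randn(1, 77, 768)
old_format = [[cross_attn, {"strength": 0.8}]]  # old style: [tensor, options] pairs
new_format = convert_cond(old_format)           # new style: one dict per cond
print(new_format[0]["strength"])                               # 0.8
print(new_format[0]["model_conds"]["c_crossattn"].cond.shape)  # torch.Size([1, 77, 768])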
comfy/samplers.py

@@ -2,96 +2,44 @@ from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
-import enum
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from comfy import model_base
import comfy.utils
-
-def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
-    return abs(a*b) // math.gcd(a, b)
-
-class CONDRegular:
-    def __init__(self, cond):
-        self.cond = cond
-
-    def can_concat(self, other):
-        if self.cond.shape != other.cond.shape:
-            return False
-        return True
-
-    def concat(self, others):
-        conds = [self.cond]
-        for x in others:
-            conds.append(x.cond)
-        return torch.cat(conds)
-
-class CONDCrossAttn:
-    def __init__(self, cond):
-        self.cond = cond
-
-    def can_concat(self, other):
-        s1 = self.cond.shape
-        s2 = other.cond.shape
-        if s1 != s2:
-            if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
-                return False
-
-            mult_min = lcm(s1[1], s2[1])
-            diff = mult_min // min(s1[1], s2[1])
-            if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
-                return False
-        return True
-
-    def concat(self, others):
-        conds = [self.cond]
-        crossattn_max_len = self.cond.shape[1]
-        for x in others:
-            c = x.cond
-            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
-            conds.append(c)
-        out = []
-        for c in conds:
-            if c.shape[1] < crossattn_max_len:
-                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
-            out.append(c)
-        return torch.cat(out)
+import comfy.conds

#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
-    def get_area_and_mult(cond, x_in, timestep_in):
+    def get_area_and_mult(conds, x_in, timestep_in):
        area = (x_in.shape[2], x_in.shape[3], 0, 0)
        strength = 1.0
-        if 'timestep_start' in cond[1]:
-            timestep_start = cond[1]['timestep_start']
+        if 'timestep_start' in conds:
+            timestep_start = conds['timestep_start']
            if timestep_in[0] > timestep_start:
                return None
-        if 'timestep_end' in cond[1]:
-            timestep_end = cond[1]['timestep_end']
+        if 'timestep_end' in conds:
+            timestep_end = conds['timestep_end']
            if timestep_in[0] < timestep_end:
                return None
-        if 'area' in cond[1]:
-            area = cond[1]['area']
-        if 'strength' in cond[1]:
-            strength = cond[1]['strength']
-        adm_cond = None
-        if 'adm_encoded' in cond[1]:
-            adm_cond = cond[1]['adm_encoded']
+        if 'area' in conds:
+            area = conds['area']
+        if 'strength' in conds:
+            strength = conds['strength']

        input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-        if 'mask' in cond[1]:
+        if 'mask' in conds:
            # Scale the mask to the size of the input
            # The mask should have been resized as we began the sampling process
            mask_strength = 1.0
-            if "mask_strength" in cond[1]:
-                mask_strength = cond[1]["mask_strength"]
-            mask = cond[1]['mask']
+            if "mask_strength" in conds:
+                mask_strength = conds["mask_strength"]
+            mask = conds['mask']
            assert(mask.shape[1] == x_in.shape[2])
            assert(mask.shape[2] == x_in.shape[3])
            mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
@@ -100,7 +48,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
            mask = torch.ones_like(input_x)
        mult = mask * strength

-        if 'mask' not in cond[1]:
+        if 'mask' not in conds:
            rr = 8
            if area[2] != 0:
                for t in range(rr):
@@ -116,27 +64,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
                    mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))

        conditionning = {}
-        conditionning['c_crossattn'] = CONDCrossAttn(cond[0])
-
-        if 'concat' in cond[1]:
-            cond_concat_in = cond[1]['concat']
-            if cond_concat_in is not None and len(cond_concat_in) > 0:
-                cropped = []
-                for x in cond_concat_in:
-                    cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-                    cropped.append(cr)
-                conditionning['c_concat'] = CONDRegular(torch.cat(cropped, dim=1))
-
-        if adm_cond is not None:
-            conditionning['c_adm'] = CONDRegular(adm_cond)
+        model_conds = conds["model_conds"]
+        for c in model_conds:
+            conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)

        control = None
-        if 'control' in cond[1]:
-            control = cond[1]['control']
+        if 'control' in conds:
+            control = conds['control']

        patches = None
-        if 'gligen' in cond[1]:
-            gligen = cond[1]['gligen']
+        if 'gligen' in conds:
+            gligen = conds['gligen']
            patches = {}
            gligen_type = gligen[0]
            gligen_model = gligen[1]
@@ -412,19 +350,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
    for i in range(len(conditions)):
        c = conditions[i]
-        if 'area' in c[1]:
-            area = c[1]['area']
+        if 'area' in c:
+            area = c['area']
            if area[0] == "percentage":
-                modified = c[1].copy()
+                modified = c.copy()
                area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
                modified['area'] = area
-                c = [c[0], modified]
+                c = modified
                conditions[i] = c

-        if 'mask' in c[1]:
-            mask = c[1]['mask']
+        if 'mask' in c:
+            mask = c['mask']
            mask = mask.to(device=device)
-            modified = c[1].copy()
+            modified = c.copy()
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0)
            if mask.shape[1] != h or mask.shape[2] != w:
@@ -445,37 +383,39 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
                    modified['area'] = area

            modified['mask'] = mask
-            conditions[i] = [c[0], modified]
+            conditions[i] = modified

def create_cond_with_same_area_if_none(conds, c):
-    if 'area' not in c[1]:
+    if 'area' not in c:
        return

-    c_area = c[1]['area']
+    c_area = c['area']
    smallest = None
    for x in conds:
-        if 'area' in x[1]:
-            a = x[1]['area']
+        if 'area' in x:
+            a = x['area']
            if c_area[2] >= a[2] and c_area[3] >= a[3]:
                if a[0] + a[2] >= c_area[0] + c_area[2]:
                    if a[1] + a[3] >= c_area[1] + c_area[3]:
                        if smallest is None:
                            smallest = x
-                        elif 'area' not in smallest[1]:
+                        elif 'area' not in smallest:
                            smallest = x
                        else:
-                            if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
+                            if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
                                smallest = x
        else:
            if smallest is None:
                smallest = x
    if smallest is None:
        return
-    if 'area' in smallest[1]:
-        if smallest[1]['area'] == c_area:
+    if 'area' in smallest:
+        if smallest['area'] == c_area:
            return
-    n = c[1].copy()
-    conds += [[smallest[0], n]]
+    out = c.copy()
+    out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
+    conds += [out]

def calculate_start_end_timesteps(model, conds):
    for t in range(len(conds)):
@@ -483,18 +423,18 @@ def calculate_start_end_timesteps(model, conds):
        timestep_start = None
        timestep_end = None
-        if 'start_percent' in x[1]:
-            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
-        if 'end_percent' in x[1]:
-            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
+        if 'start_percent' in x:
+            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
+        if 'end_percent' in x:
+            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))

        if (timestep_start is not None) or (timestep_end is not None):
-            n = x[1].copy()
+            n = x.copy()
            if (timestep_start is not None):
                n['timestep_start'] = timestep_start
            if (timestep_end is not None):
                n['timestep_end'] = timestep_end
-            conds[t] = [x[0], n]
+            conds[t] = n

def pre_run_control(model, conds):
    for t in range(len(conds)):
@@ -503,8 +443,8 @@ def pre_run_control(model, conds):
        timestep_start = None
        timestep_end = None
        percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
-        if 'control' in x[1]:
-            x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
+        if 'control' in x:
+            x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)

def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    cond_cnets = []
@@ -513,16 +453,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    uncond_other = []
    for t in range(len(conds)):
        x = conds[t]
-        if 'area' not in x[1]:
-            if name in x[1] and x[1][name] is not None:
-                cond_cnets.append(x[1][name])
+        if 'area' not in x:
+            if name in x and x[name] is not None:
+                cond_cnets.append(x[name])
            else:
                cond_other.append((x, t))
    for t in range(len(uncond)):
        x = uncond[t]
-        if 'area' not in x[1]:
-            if name in x[1] and x[1][name] is not None:
-                uncond_cnets.append(x[1][name])
+        if 'area' not in x:
+            if name in x and x[name] is not None:
+                uncond_cnets.append(x[name])
            else:
                uncond_other.append((x, t))
@@ -532,47 +472,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    for x in range(len(cond_cnets)):
        temp = uncond_other[x % len(uncond_other)]
        o = temp[0]
-        if name in o[1] and o[1][name] is not None:
-            n = o[1].copy()
+        if name in o and o[name] is not None:
+            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
-            uncond += [[o[0], n]]
+            uncond += [n]
        else:
-            n = o[1].copy()
+            n = o.copy()
            n[name] = uncond_fill_func(cond_cnets, x)
-            uncond[temp[1]] = [o[0], n]
+            uncond[temp[1]] = n

-def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
+def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
    for t in range(len(conds)):
        x = conds[t]
-        adm_out = None
-        if 'adm' in x[1]:
-            adm_out = x[1]["adm"]
-        else:
-            params = x[1].copy()
-            params["width"] = params.get("width", width * 8)
-            params["height"] = params.get("height", height * 8)
-            params["prompt_type"] = params.get("prompt_type", prompt_type)
-            adm_out = model.encode_adm(device=device, **params)
-        if adm_out is not None:
-            x[1] = x[1].copy()
-            x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device)
-    return conds
-
-def encode_cond(model_function, key, conds, device, **kwargs):
-    for t in range(len(conds)):
-        x = conds[t]
-        params = x[1].copy()
+        params = x.copy()
        params["device"] = device
+        params["noise"] = noise
+        params["width"] = params.get("width", noise.shape[3] * 8)
+        params["height"] = params.get("height", noise.shape[2] * 8)
+        params["prompt_type"] = params.get("prompt_type", prompt_type)
        for k in kwargs:
            if k not in params:
                params[k] = kwargs[k]

        out = model_function(**params)
-        if out is not None:
-            x[1] = x[1].copy()
-            x[1][key] = out
+        x = x.copy()
+        model_conds = x['model_conds'].copy()
+        for k in out:
+            model_conds[k] = out[k]
+        x['model_conds'] = model_conds
+        conds[t] = x
    return conds

class Sampler:
@@ -690,19 +618,15 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    pre_run_control(model_wrap, negative + positive)

-    apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+    apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
    apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

    if latent_image is not None:
        latent_image = model.process_latent_in(latent_image)

-    if model.is_adm():
-        positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
-        negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
-
-    if hasattr(model, 'cond_concat'):
-        positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
-        negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
+    if hasattr(model, 'extra_conds'):
+        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)

    extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}