From 07db00355f890c095a1137feed103e360914e7bf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 01:49:17 -0500 Subject: [PATCH 01/11] Add masks to samplers code for inpainting. --- comfy/extra_samplers/uni_pc.py | 23 ++++++++++++++----- comfy/samplers.py | 40 +++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index ae3a544e..cfd7225b 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -358,7 +358,10 @@ class UniPC: predict_x0=True, thresholding=False, max_val=1., - variant='bh1' + variant='bh1', + noise_mask=None, + masked_image=None, + noise=None, ): """Construct a UniPC. @@ -370,7 +373,10 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val - + self.noise_mask = noise_mask + self.masked_image = masked_image + self.noise = noise + def dynamic_thresholding_fn(self, x0, t=None): """ The dynamic thresholding method. @@ -386,7 +392,10 @@ class UniPC: """ Return the noise prediction model. """ - return self.model(x, t) + if self.noise_mask is not None: + return self.model(x, t) * self.noise_mask + else: + return self.model(x, t) def data_prediction_fn(self, x, t): """ @@ -401,6 +410,8 @@ class UniPC: s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s + if self.noise_mask is not None: + x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image return x0 def model_fn(self, x, t): @@ -713,6 +724,8 @@ class UniPC: assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): for step_index in trange(steps): + if self.noise_mask is not None: + x = x * self.noise_mask + (1. 
- self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -820,7 +833,7 @@ def expand_dims(v, dims): -def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None): +def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None): to_zero = False if sigmas[-1] == 0: timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] @@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None model_kwargs=extra_args, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7f6dc972..b806381e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -139,8 +139,17 @@ class CFGDenoiserComplex(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale): - return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask): + if denoise_mask is not None: + latent_mask = 1. - denoise_mask + x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask + out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) + if denoise_mask is not None: + out *= denoise_mask + + if denoise_mask is not None: + out += self.latent_image * latent_mask + return out def simple_scheduler(model, steps): sigs = [] @@ -200,8 +209,8 @@ class KSampler: sampler = self.SAMPLERS[0] self.scheduler = scheduler self.sampler = sampler - self.sigma_min=float(self.model_wrap.sigmas[0]) - self.sigma_max=float(self.model_wrap.sigmas[-1]) + self.sigma_min=float(self.model_wrap.sigma_min) + self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) def _calculate_sigmas(self, steps): @@ -235,7 +244,7 @@ class KSampler: self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): sigmas = self.sigmas sigma_min = self.sigma_min @@ -267,17 +276,28 @@ class KSampler: else: precision_scope = contextlib.nullcontext + latent_mask = None + if denoise_mask is not None: + latent_mask = (torch.ones_like(denoise_mask) - denoise_mask) + + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} with precision_scope(self.device): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, 
noise_mask=denoise_mask) else: - noise *= sigmas[0] + extra_args["denoise_mask"] = denoise_mask + self.model_k.latent_image = latent_image + self.model_k.noise = noise + + noise = noise * sigmas[0] + if latent_image is not None: noise += latent_image if self.sampler == "sample_dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) elif self.sampler == "sample_dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) else: - samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args) + return samples.to(torch.float32) From cef2cc3cb0c2454d89567c96929ad608df033cca Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 16:38:20 -0500 Subject: [PATCH 02/11] Support for inpaint models. --- comfy/samplers.py | 88 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 15 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index b806381e..c8ab106b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -21,8 +21,8 @@ class CFGDenoiser(torch.nn.Module): uncond = self.inner_model(x, sigma, cond=uncond) return uncond + (cond - uncond) * cond_scale -def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): - def get_area_and_mult(cond, x_in): +def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None): + def get_area_and_mult(cond, x_in, cond_concat_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 min_sigma = 0.0 @@ -48,9 +48,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): if (area[1] + area[3]) < x_in.shape[3]: for t in range(rr): mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1)) - return (input_x, mult, cond[0], area) + conditionning = {} + conditionning['c_crossattn'] = cond[0] + if cond_concat_in is not None and len(cond_concat_in) > 0: + cropped = [] + for x in cond_concat_in: + cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + cropped.append(cr) + conditionning['c_concat'] = torch.cat(cropped, dim=1) + return (input_x, mult, conditionning, area) - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area): + def cond_equal_size(c1, c2): + if c1.keys() != c2.keys(): + return False + if 'c_crossattn' in c1: + if c1['c_crossattn'].shape != c2['c_crossattn'].shape: + return False + if 'c_concat' in c1: + if c1['c_concat'].shape != c2['c_concat'].shape: + return False + return True + + def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + for x in c_list: + if 'c_crossattn' in x: + c_crossattn.append(x['c_crossattn']) + if 'c_concat' in x: + c_concat.append(x['c_concat']) + out = {} + if len(c_crossattn) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn)] + if len(c_concat) > 0: + out['c_concat'] = [torch.cat(c_concat)] + return out + + def calc_cond_uncond_batch(model_function, cond, 
uncond, x_in, sigma, max_total_area, cond_concat_in): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -62,13 +96,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): to_run = [] for x in cond: - p = get_area_and_mult(x, x_in) + p = get_area_and_mult(x, x_in, cond_concat_in) if p is None: continue to_run += [(p, COND)] for x in uncond: - p = get_area_and_mult(x, x_in) + p = get_area_and_mult(x, x_in, cond_concat_in) if p is None: continue @@ -80,7 +114,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): to_batch_temp = [] for x in range(len(to_run)): if to_run[x][0][0].shape == first_shape: - if to_run[x][0][2].shape == first[0][2].shape: + if cond_equal_size(to_run[x][0][2], first[0][2]): to_batch_temp += [x] to_batch_temp.reverse() @@ -108,7 +142,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) - c = torch.cat(c) + c = cond_cat(c) sigma_ = torch.cat([sigma] * batch_chunks) output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) @@ -132,18 +166,18 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat) return uncond + (cond - uncond) * cond_scale class CFGDenoiserComplex(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): if denoise_mask is not None: latent_mask = 1. 
- denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask - out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) + out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat) if denoise_mask is not None: out *= denoise_mask @@ -159,6 +193,17 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) +def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + # the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE + # unfortunately that gives zero flexibility so I did things like this instead which hopefully works + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + def create_cond_with_same_area_if_none(conds, c): if 'area' not in c[1]: return @@ -276,11 +321,24 @@ class KSampler: else: precision_scope = contextlib.nullcontext - latent_mask = None - if denoise_mask is not None: - latent_mask = (torch.ones_like(denoise_mask) - denoise_mask) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} + + if hasattr(self.model, 'concat_keys'): + cond_concat = [] + for ck in self.model.concat_keys: + if denoise_mask is not None: + if ck == "mask": + cond_concat.append(denoise_mask[:,:1]) + elif ck == "masked_image": + blank_image = blank_inpaint_image_like(latent_image) + cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image) + else: + if ck == "mask": + cond_concat.append(torch.ones_like(noise)[:,:1]) + elif ck == "masked_image": + cond_concat.append(blank_inpaint_image_like(noise)) + extra_args["cond_concat"] = cond_concat + with precision_scope(self.device): if self.sampler == "uni_pc": samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask) From d75003001a8e5c8d5149c87a43410730cee1dace Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 16:58:55 -0500 Subject: [PATCH 03/11] Add a SetLatentNoiseMask node. LATENT is now a dict that can contain properties. 
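For reference, a minimal sketch of the dict layout this patch introduces (tensor sizes are illustrative): a LATENT is now a plain Python dict whose "samples" key holds the latent tensor, and SetLatentNoiseMask simply copies the dict and attaches the mask under "noise_mask", which common_ksampler later rescales to the latent resolution and hands to the sampler as denoise_mask.

    import torch

    # what EmptyLatentImage now returns: the latent tensor wrapped in a dict
    latent = {"samples": torch.zeros([1, 4, 512 // 8, 512 // 8])}

    # what SetLatentNoiseMask does: copy the dict and attach the mask as an extra property
    mask = torch.ones((512, 512), dtype=torch.float32)  # illustrative size; common_ksampler interpolates it
    latent_with_mask = latent.copy()
    latent_with_mask["noise_mask"] = mask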
--- nodes.py | 78 +++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 18 deletions(-) diff --git a/nodes.py b/nodes.py index aaa4f87a..c190d633 100644 --- a/nodes.py +++ b/nodes.py @@ -98,7 +98,7 @@ class VAEDecode: CATEGORY = "latent" def decode(self, vae, samples): - return (vae.decode(samples), ) + return (vae.decode(samples["samples"]), ) class VAEEncode: def __init__(self, device="cpu"): @@ -117,7 +117,9 @@ class VAEEncode: y = (pixels.shape[2] // 64) * 64 if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - return (vae.encode(pixels), ) + t = vae.encode(pixels[:,:,:,:3]) + + return ({"samples":t}, ) class CheckpointLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") @@ -212,7 +214,7 @@ class EmptyLatentImage: def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (latent, ) + return ({"samples":latent}, ) def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": @@ -247,7 +249,8 @@ class LatentUpscale: CATEGORY = "latent" def upscale(self, samples, upscale_method, width, height, crop): - s = common_upscale(samples, width // 8, height // 8, upscale_method, crop) + s = samples.copy() + s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) class LatentRotate: @@ -262,6 +265,7 @@ class LatentRotate: CATEGORY = "latent" def rotate(self, samples, rotation): + s = samples.copy() rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -270,7 +274,7 @@ class LatentRotate: elif rotation.startswith("270"): rotate_by = 3 - s = torch.rot90(samples, k=rotate_by, dims=[3, 2]) + s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2]) return (s,) class LatentFlip: @@ -285,12 +289,11 @@ class LatentFlip: CATEGORY = "latent" def flip(self, samples, flip_method): + s = samples.copy() if flip_method.startswith("x"): - s = torch.flip(samples, dims=[2]) + s["samples"] = torch.flip(samples["samples"], dims=[2]) elif flip_method.startswith("y"): - s = torch.flip(samples, dims=[3]) - else: - s = samples + s["samples"] = torch.flip(samples["samples"], dims=[3]) return (s,) @@ -312,12 +315,15 @@ class LatentComposite: x = x // 8 y = y // 8 feather = feather // 8 - s = samples_to.clone() + samples_out = samples_to.copy() + s = samples_to["samples"].clone() + samples_to = samples_to["samples"] + samples_from = samples_from["samples"] if feather == 0: s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] else: - s_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] - mask = torch.ones_like(s_from) + samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] + mask = torch.ones_like(samples_from) for t in range(feather): if y != 0: mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) @@ -330,7 +336,8 @@ class LatentComposite: mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) rev_mask = torch.ones_like(mask) - mask s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask - return (s,) + samples_out["samples"] = s + return (samples_out,) class LatentCrop: @classmethod @@ -347,6 +354,8 @@ class LatentCrop: CATEGORY = "latent" def crop(self, samples, 
width, height, x, y): + s = samples.copy() + samples = samples['samples'] x = x // 8 y = y // 8 @@ -370,15 +379,46 @@ class LatentCrop: #make sure size is always multiple of 64 x, to_x = enforce_image_dim(x, to_x, samples.shape[3]) y, to_y = enforce_image_dim(y, to_y, samples.shape[2]) - s = samples[:,:,y:to_y, x:to_x] + s['samples'] = samples[:,:,y:to_y, x:to_x] return (s,) -def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): +class SetLatentNoiseMask: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "mask": ("MASK",), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "set_mask" + + CATEGORY = "latent" + + def set_mask(self, samples, mask): + s = samples.copy() + s["noise_mask"] = mask + return (s,) + + +def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): + latent_image = latent["samples"] + noise_mask = None + if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") + if "noise_mask" in latent: + noise_mask = latent['noise_mask'] + print(noise_mask.shape, noise.shape) + + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.floor() + noise_mask = torch.ones_like(noise_mask) - noise_mask + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) + real_model = None if device != "cpu": model_management.load_model_gpu(model) @@ -411,10 +451,11 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po #other samplers pass - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() - - return (samples, ) + out = latent.copy() + out["samples"] = samples + return (out, ) class KSampler: def __init__(self, device="cuda"): @@ -589,6 +630,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, "KSamplerAdvanced": KSamplerAdvanced, + "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, "LatentRotate": LatentRotate, "LatentFlip": LatentFlip, From e87a8669b66af5dc08d83c7ef29c386618db9927 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 17:39:42 -0500 Subject: [PATCH 04/11] Add a LoadImageMask node to load one colour channel in an image as a mask. 
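The node reads a single channel of an input image as a float mask, with the alpha channel inverted so that transparent pixels end up in the region to be denoised. A condensed sketch of the channel extraction, mirroring LoadImageMask.load_image in the hunk below (the file path is a placeholder):

    from PIL import Image
    import numpy as np
    import torch

    i = Image.open("input/example.png")   # placeholder path; the node lists files in its input directory
    c = "A"                               # first letter of the chosen channel: A, R, G or B
    if c in i.getbands():
        mask = torch.from_numpy(np.array(i.getchannel(c)).astype(np.float32) / 255.0)
        if c == "A":
            mask = 1.0 - mask             # invert alpha: transparent pixels become the masked area
    else:
        mask = torch.zeros((64, 64), dtype=torch.float32)  # fallback when the channel is missing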
--- nodes.py | 48 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/nodes.py b/nodes.py index c190d633..0a41511c 100644 --- a/nodes.py +++ b/nodes.py @@ -410,11 +410,8 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po if "noise_mask" in latent: noise_mask = latent['noise_mask'] - print(noise_mask.shape, noise.shape) - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.floor() - noise_mask = torch.ones_like(noise_mask) - noise_mask + noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) noise_mask = torch.cat([noise_mask] * noise.shape[0]) noise_mask = noise_mask.to(device) @@ -581,10 +578,11 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): image_path = os.path.join(self.input_dir, image) - image = Image.open(image_path).convert("RGB") + i = Image.open(image_path) + image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image[None])[None,] - return image + image = torch.from_numpy(image)[None,] + return (image,) @classmethod def IS_CHANGED(s, image): @@ -594,6 +592,41 @@ class LoadImage: m.update(f.read()) return m.digest().hex() +class LoadImageMask: + input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + @classmethod + def INPUT_TYPES(s): + return {"required": + {"image": (os.listdir(s.input_dir), ), + "channel": (["alpha", "red", "green", "blue"], ),} + } + + CATEGORY = "image" + + RETURN_TYPES = ("MASK",) + FUNCTION = "load_image" + def load_image(self, image, channel): + image_path = os.path.join(self.input_dir, image) + i = Image.open(image_path) + mask = None + c = channel[0].upper() + if c in i.getbands(): + mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 + mask = torch.from_numpy(mask) + if c == 'A': + mask = 1. - mask + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + return (mask,) + + @classmethod + def IS_CHANGED(s, image, channel): + image_path = os.path.join(s.input_dir, image) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -626,6 +659,7 @@ NODE_CLASS_MAPPINGS = { "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "LoadImage": LoadImage, + "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, From 59bef84bc8564f745a83b932537ca9e555d93e51 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 17:52:34 -0500 Subject: [PATCH 05/11] Add the config for SD2.x inpainting models. 
--- models/configs/v2-inpainting-inference.yaml | 158 ++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 models/configs/v2-inpainting-inference.yaml diff --git a/models/configs/v2-inpainting-inference.yaml b/models/configs/v2-inpainting-inference.yaml new file mode 100644 index 00000000..32a9471d --- /dev/null +++ b/models/configs/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: 
+ target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 From bc69fb5245a4f892c39fef48ed672b1bfe7f492a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 20:44:51 -0500 Subject: [PATCH 06/11] Use inpaint models the proper way by using VAEEncodeForInpaint. --- comfy/samplers.py | 3 +-- nodes.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index c8ab106b..7e2e667b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -330,8 +330,7 @@ class KSampler: if ck == "mask": cond_concat.append(denoise_mask[:,:1]) elif ck == "masked_image": - blank_image = blank_inpaint_image_like(latent_image) - cond_concat.append(latent_image * (1.0 - denoise_mask) + denoise_mask * blank_image) + cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space else: if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) diff --git a/nodes.py b/nodes.py index 0a41511c..5784ba3d 100644 --- a/nodes.py +++ b/nodes.py @@ -121,6 +121,38 @@ class VAEEncode: return ({"samples":t}, ) +class VAEEncodeForInpaint: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "latent/inpaint" + + def encode(self, vae, pixels, mask): + print(pixels.shape, mask.shape) + x = (pixels.shape[1] // 64) * 64 + y = (pixels.shape[2] // 64) * 64 + if pixels.shape[1] != x or pixels.shape[2] != y: + pixels = pixels[:,:x,:y,:] + mask = mask[:x,:y] + + #shave off a few pixels to keep things seamless + kernel_tensor = torch.ones((1, 1, 6, 6)) + mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1) + print(mask_erosion.shape, pixels.shape) + for i in range(3): + pixels[:,:,:,i] -= 0.5 + pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round() + pixels[:,:,:,i] += 0.5 + t = vae.encode(pixels) + + return ({"samples":t, "noise_mask": mask}, ) + class CheckpointLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") config_dir = os.path.join(models_dir, "configs") @@ -391,7 +423,7 @@ class SetLatentNoiseMask: RETURN_TYPES = ("LATENT",) FUNCTION = "set_mask" - CATEGORY = "latent" + CATEGORY = "latent/inpaint" def set_mask(self, samples, mask): s = samples.copy() @@ -654,6 +686,7 @@ NODE_CLASS_MAPPINGS = { "CLIPTextEncode": CLIPTextEncode, "VAEDecode": VAEDecode, "VAEEncode": VAEEncode, + "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, From 4efa67fa1239b49bdbdb944ac1980a6a4730b5e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Feb 2023 10:38:08 -0500 Subject: [PATCH 07/11] Add ControlNet support. 
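The ControlNet model is the usual trainable copy of the UNet encoder: it takes the noisy latent plus an RGB hint image and returns a list of zero-convolution residuals, one per encoder block plus one for the middle block. The patched UNetModel.forward (see the openaimodel.py hunk below) adds the middle residual at the bottleneck and pops the remaining ones onto the skip connections. A rough sketch of the sampling-side data flow using the names from the hunks below, with placeholder paths and tensors:

    import torch
    import comfy.sd

    hint_image = torch.rand(1, 512, 512, 3)   # stand-in IMAGE tensor (NHWC, 0..1), e.g. a canny edge map
    control = comfy.sd.load_controlnet("models/controlnet/example.pth")   # placeholder file name
    control = control.copy().set_cond_hint(hint_image.movedim(-1, 1))     # hint becomes NCHW, upscaled to the latent size on first use

    # during sampling, calc_cond_uncond_batch does roughly:
    #   c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
    # and that residual list reaches UNetModel.forward as control=, where it is added to the
    # middle block and to each skip connection (h += control.pop() / hsp += control.pop())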
--- comfy/cldm/cldm.py | 286 ++++++++++++++++++ comfy/extra_samplers/uni_pc.py | 4 +- comfy/ldm/models/diffusion/ddpm.py | 18 +- .../modules/diffusionmodules/openaimodel.py | 10 +- comfy/model_management.py | 17 +- comfy/samplers.py | 121 ++++++-- comfy/sd.py | 76 +++++ comfy/utils.py | 18 ++ nodes.py | 93 ++++-- 9 files changed, 580 insertions(+), 63 deletions(-) create mode 100644 comfy/cldm/cldm.py create mode 100644 comfy/utils.py diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py new file mode 100644 index 00000000..c75830ae --- /dev/null +++ b/comfy/cldm/cldm.py @@ -0,0 +1,286 @@ +#taken from: https://github.com/lllyasviel/ControlNet +#and modified + +import einops +import torch +import torch as th +import torch.nn as nn + +from ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + zero_module, + timestep_embedding, +) + +from einops import rearrange, repeat +from torchvision.utils import make_grid +from ldm.modules.attention import SpatialTransformer +from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.util import log_txt_as_img, exists, instantiate_from_config + + +class ControlledUnetModel(UNetModel): + #implemented in the ldm unet + pass + +class ControlNet(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + hint_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self.input_hint_block = TimestepEmbedSequential( + conv_nd(dims, hint_channels, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) + ) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if 
legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + guided_hint = self.input_hint_block(hint, emb, context) + + outs = [] + + h = x.type(self.dtype) + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + return outs + diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index cfd7225b..bcc7c0f2 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -856,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, 
sampling_function, extra_args=None device = noise.device - if model.inner_model.parameterization == "v": + if model.parameterization == "v": model_type = "v" else: model_type = "noise" model_fn = model_wrapper( - model.inner_model.apply_model, + model.inner_model.inner_model.apply_model, sampling_function, ns, model_type=model_type, diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 074919d0..efe20a3b 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module): self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): if self.conditioning_key is None: - out = self.diffusion_model(x, t) + out = self.diffusion_model(x, t, control=control) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) + out = self.diffusion_model(xc, t, control=control) elif self.conditioning_key == 'crossattn': if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) @@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module): # TorchScript changes names of the arguments # with argument cc defined as context=cc scripted model will produce # an error: RuntimeError: forward() is missing value for argument 'argument_3'. - out = self.scripted_diffusion_model(x, t, cc) + out = self.scripted_diffusion_model(x, t, cc, control=control) else: - out = self.diffusion_model(x, t, context=cc) + out = self.diffusion_model(x, t, context=cc, control=control) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc) + out = self.diffusion_model(xc, t, context=cc, control=control) elif self.conditioning_key == 'hybrid-adm': assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, y=c_adm) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) elif self.conditioning_key == 'crossattn-adm': assert c_adm is not None cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc, y=c_adm) + out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) elif self.conditioning_key == 'adm': cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) + out = self.diffusion_model(x, t, y=cc, control=control) else: raise NotImplementedError() diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 764a34b8..1769cc00 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -753,7 +753,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. 
@@ -778,8 +778,14 @@ class UNetModel(nn.Module): h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) + if control is not None: + h += control.pop() + for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) + hsp = hs.pop() + if control is not None: + hsp += control.pop() + h = th.cat([h, hsp], dim=1) h = module(h, emb, context) h = h.type(x.dtype) if self.predict_codebook_ids: diff --git a/comfy/model_management.py b/comfy/model_management.py index ff7cbeb0..b8fd8796 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s current_loaded_model = None - +current_gpu_controlnets = [] model_accelerated = False @@ -56,6 +56,7 @@ model_accelerated = False def unload_model(): global current_loaded_model global model_accelerated + global current_gpu_controlnets if current_loaded_model is not None: if model_accelerated: accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) @@ -64,6 +65,10 @@ def unload_model(): current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None + if len(current_gpu_controlnets) > 0: + for n in current_gpu_controlnets: + n.cpu() + current_gpu_controlnets = [] def load_model_gpu(model): @@ -95,6 +100,16 @@ def load_model_gpu(model): model_accelerated = True return current_loaded_model +def load_controlnet_gpu(models): + global current_gpu_controlnets + for m in current_gpu_controlnets: + if m not in models: + m.cpu() + + current_gpu_controlnets = [] + for m in models: + current_gpu_controlnets.append(m.cuda()) + def get_free_memory(): dev = torch.cuda.current_device() diff --git a/comfy/samplers.py b/comfy/samplers.py index 7e2e667b..a5a31811 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module): uncond = self.inner_model(x, sigma, cond=uncond) return uncond + (cond - uncond) * cond_scale -def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None): - def get_area_and_mult(cond, x_in, cond_concat_in): + +#The main sampling function shared by all the samplers +#Returns predicted noise +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): + def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 - min_sigma = 0.0 - max_sigma = 999.0 if 'area' in cond[1]: area = cond[1]['area'] if 'strength' in cond[1]: @@ -56,9 +57,15 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] cropped.append(cr) conditionning['c_concat'] = torch.cat(cropped, dim=1) - return (input_x, mult, conditionning, area) + + control = None + if 'control' in cond[1]: + control = cond[1]['control'] + return (input_x, mult, conditionning, area, control) def cond_equal_size(c1, c2): + if c1 is c2: + return True if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: @@ -69,6 +76,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c return False return True + def can_concat_cond(c1, c2): + if c1[0].shape != c2[0].shape: + return False + if (c1[4] is None) != (c2[4] is None): + return False + if c1[4] is not None: + if c1[4] is not c2[4]: + return False + + return cond_equal_size(c1[2], c2[2]) + def cond_cat(c_list): c_crossattn = [] c_concat = [] @@ -84,7 +102,7 @@ def 
sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c out['c_concat'] = [torch.cat(c_concat)] return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in): + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -96,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c to_run = [] for x in cond: - p = get_area_and_mult(x, x_in, cond_concat_in) + p = get_area_and_mult(x, x_in, cond_concat_in, timestep) if p is None: continue to_run += [(p, COND)] for x in uncond: - p = get_area_and_mult(x, x_in, cond_concat_in) + p = get_area_and_mult(x, x_in, cond_concat_in, timestep) if p is None: continue @@ -113,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c first_shape = first[0][0].shape to_batch_temp = [] for x in range(len(to_run)): - if to_run[x][0][0].shape == first_shape: - if cond_equal_size(to_run[x][0][2], first[0][2]): - to_batch_temp += [x] + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] to_batch_temp.reverse() to_batch = to_batch_temp[:1] @@ -131,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c c = [] cond_or_uncond = [] area = [] + control = None for x in to_batch: o = to_run.pop(x) p = o[0] @@ -139,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c c += [p[2]] area += [p[3]] cond_or_uncond += [o[1]] + control = p[4] batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) c = cond_cat(c) - sigma_ = torch.cat([sigma] * batch_chunks) + timestep_ = torch.cat([timestep] * batch_chunks) - output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c['c_crossattn']) + + output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x for o in range(batch_chunks): @@ -166,10 +188,29 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat) return uncond + (cond - uncond) * cond_scale -class CFGDenoiserComplex(torch.nn.Module): + +class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_v(self, x, t, cond, **kwargs): + return self.inner_model.apply_model(x, t, cond, **kwargs) + + +class CFGNoisePredictor(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + self.alphas_cumprod = model.alphas_cumprod + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat) + return out + + +class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model @@ -177,7 +218,7 @@ class CFGDenoiserComplex(torch.nn.Module): if denoise_mask is not None: latent_mask = 1. 
- denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask - out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat) if denoise_mask is not None: out *= denoise_mask @@ -196,8 +237,6 @@ def simple_scheduler(model, steps): def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space - # the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE - # unfortunately that gives zero flexibility so I did things like this instead which hopefully works blank_image[:,0] *= 0.8223 blank_image[:,1] *= -0.6876 blank_image[:,2] *= 0.6364 @@ -234,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c): n = c[1].copy() conds += [[smallest[0], n]] + +def apply_control_net_to_equal_area(conds, uncond): + cond_cnets = [] + cond_other = [] + uncond_cnets = [] + uncond_other = [] + for t in range(len(conds)): + x = conds[t] + if 'area' not in x[1]: + if 'control' in x[1] and x[1]['control'] is not None: + cond_cnets.append(x[1]['control']) + else: + cond_other.append((x, t)) + for t in range(len(uncond)): + x = uncond[t] + if 'area' not in x[1]: + if 'control' in x[1] and x[1]['control'] is not None: + uncond_cnets.append(x[1]['control']) + else: + uncond_other.append((x, t)) + + if len(uncond_cnets) > 0: + return + + for x in range(len(cond_cnets)): + temp = uncond_other[x % len(uncond_other)] + o = temp[0] + if 'control' in o[1] and o[1]['control'] is not None: + n = o[1].copy() + n['control'] = cond_cnets[x] + uncond += [[o[0], n]] + else: + n = o[1].copy() + n['control'] = cond_cnets[x] + uncond[temp[1]] = [o[0], n] + class KSampler: SCHEDULERS = ["karras", "normal", "simple"] SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral", @@ -242,11 +317,13 @@ class KSampler: def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): self.model = model + self.model_denoise = CFGNoisePredictor(self.model) if self.model.parameterization == "v": - self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True) + self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) else: - self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True) - self.model_k = CFGDenoiserComplex(self.model_wrap) + self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True) + self.model_wrap.parameterization = self.model.parameterization + self.model_k = KSamplerX0Inpaint(self.model_wrap) self.device = device if scheduler not in self.SCHEDULERS: scheduler = self.SCHEDULERS[0] @@ -316,6 +393,8 @@ class KSampler: for c in negative: create_cond_with_same_area_if_none(positive, c) + apply_control_net_to_equal_area(positive, negative) + if self.model.model.diffusion_model.dtype == torch.float16: precision_scope = torch.autocast else: diff --git a/comfy/sd.py b/comfy/sd.py index a3c0066d..d37e5316 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -6,6 +6,9 @@ import model_management from ldm.util import instantiate_from_config from ldm.models.autoencoder import AutoencoderKL from omegaconf import OmegaConf +from .cldm import cldm + +from . 
import utils def load_torch_file(ckpt): if ckpt.lower().endswith(".safetensors"): @@ -323,6 +326,79 @@ class VAE: samples = samples.cpu() return samples +class ControlNet: + def __init__(self, control_model): + self.control_model = control_model + self.cond_hint_original = None + self.cond_hint = None + + def get_control(self, x_noisy, t, cond_txt): + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device) + print("set cond_hint", self.cond_hint.shape) + control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) + return control + + def set_cond_hint(self, cond_hint): + self.cond_hint_original = cond_hint + return self + + def cleanup(self): + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + + def copy(self): + c = ControlNet(self.control_model) + c.cond_hint_original = self.cond_hint_original + return c + +def load_controlnet(ckpt_path): + controlnet_data = load_torch_file(ckpt_path) + pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' + pth = False + sd2 = False + key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + elif key in controlnet_data: + pass + else: + print("error checkpoint does not contain controlnet data", ckpt_path) + return None + + context_dim = controlnet_data[key].shape[1] + control_model = cldm.ControlNet(image_size=32, + in_channels=4, + hint_channels=3, + model_channels=320, + attention_resolutions=[ 4, 2, 1 ], + num_res_blocks=2, + channel_mult=[ 1, 2, 4, 4 ], + num_heads=8, + use_spatial_transformer=True, + transformer_depth=1, + context_dim=context_dim, + use_checkpoint=True, + legacy=False) + + if pth: + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.control_model = control_model + w.load_state_dict(controlnet_data, strict=False) + else: + control_model.load_state_dict(controlnet_data, strict=False) + + control = ControlNet(control_model) + return control + + def load_clip(ckpt_path, embedding_directory=None): clip_data = load_torch_file(ckpt_path) config = {} diff --git a/comfy/utils.py b/comfy/utils.py new file mode 100644 index 00000000..815e899c --- /dev/null +++ b/comfy/utils.py @@ -0,0 +1,18 @@ +import torch + +def common_upscale(samples, width, height, upscale_method, crop): + if crop == "center": + old_width = samples.shape[3] + old_height = samples.shape[2] + old_aspect = old_width / old_height + new_aspect = width / height + x = 0 + y = 0 + if old_aspect > new_aspect: + x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) + elif old_aspect < new_aspect: + y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) + s = samples[:,:,y:old_height-y,x:old_width-x] + else: + s = samples + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) diff --git a/nodes.py b/nodes.py index 5784ba3d..9aec9235 100644 --- a/nodes.py +++ b/nodes.py @@ -15,10 +15,12 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy")) import comfy.samplers import comfy.sd +import comfy.utils + import model_management -supported_ckpt_extensions = ['.ckpt'] -supported_pt_extensions = ['.ckpt', '.pt', '.bin'] +supported_ckpt_extensions = 
['.ckpt', '.pth'] +supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] try: import safetensors.torch supported_ckpt_extensions += ['.safetensors'] @@ -77,12 +79,14 @@ class ConditioningSetArea: CATEGORY = "conditioning" def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0): - c = copy.deepcopy(conditioning) - for t in c: - t[1]['area'] = (height // 8, width // 8, y // 8, x // 8) - t[1]['strength'] = strength - t[1]['min_sigma'] = min_sigma - t[1]['max_sigma'] = max_sigma + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) + n[1]['strength'] = strength + n[1]['min_sigma'] = min_sigma + n[1]['max_sigma'] = max_sigma + c.append(n) return (c, ) class VAEDecode: @@ -134,7 +138,6 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask): - print(pixels.shape, mask.shape) x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 if pixels.shape[1] != x or pixels.shape[2] != y: @@ -144,7 +147,6 @@ class VAEEncodeForInpaint: #shave off a few pixels to keep things seamless kernel_tensor = torch.ones((1, 1, 6, 6)) mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1) - print(mask_erosion.shape, pixels.shape) for i in range(3): pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round() @@ -211,6 +213,44 @@ class VAELoader: vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) +class ControlNetLoader: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + controlnet_dir = os.path.join(models_dir, "controlnet") + @classmethod + def INPUT_TYPES(s): + return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + + RETURN_TYPES = ("CONTROL_NET",) + FUNCTION = "load_controlnet" + + CATEGORY = "loaders" + + def load_controlnet(self, control_net_name): + controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet = comfy.sd.load_controlnet(controlnet_path) + return (controlnet,) + + +class ControlNetApply: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_controlnet" + + CATEGORY = "conditioning" + + def apply_controlnet(self, conditioning, control_net, image): + c = [] + control_hint = image.movedim(-1,1) + print(control_hint.shape) + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['control'] = control_net.copy().set_cond_hint(control_hint) + c.append(n) + return (c, ) + + class CLIPLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") clip_dir = os.path.join(models_dir, "clip") @@ -248,22 +288,7 @@ class EmptyLatentImage: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return ({"samples":latent}, ) -def common_upscale(samples, width, height, upscale_method, crop): - if crop == "center": - old_width = samples.shape[3] - old_height = samples.shape[2] - old_aspect = old_width / old_height - new_aspect = width / height - x = 0 - y = 0 - if old_aspect > new_aspect: - x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) - elif old_aspect < new_aspect: - y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) - s = samples[:,:,y:old_height-y,x:old_width-x] - else: - s = samples - return torch.nn.functional.interpolate(s, 
size=(height, width), mode=upscale_method) + class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -282,7 +307,7 @@ class LatentUpscale: def upscale(self, samples, upscale_method, width, height, crop): s = samples.copy() - s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) + s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) class LatentRotate: @@ -461,19 +486,26 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po positive_copy = [] negative_copy = [] + control_nets = [] for p in positive: t = p[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) + if 'control' in p[1]: + control_nets += [p[1]['control']] positive_copy += [[t] + p[1:]] for n in negative: t = n[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) + if 'control' in p[1]: + control_nets += [p[1]['control']] negative_copy += [[t] + n[1:]] + model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets))) + if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) else: @@ -482,6 +514,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() + for c in control_nets: + c.cleanup() + out = latent.copy() out["samples"] = samples return (out, ) @@ -676,7 +711,7 @@ class ImageScale: def upscale(self, image, upscale_method, width, height, crop): samples = image.movedim(-1,1) - s = common_upscale(samples, width, height, upscale_method, crop) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop) s = s.movedim(1,-1) return (s,) @@ -704,6 +739,8 @@ NODE_CLASS_MAPPINGS = { "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, "CLIPLoader": CLIPLoader, + "ControlNetApply": ControlNetApply, + "ControlNetLoader": ControlNetLoader, } From 56498d505a7a518ca150cd15a382d5273d80bcee Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Feb 2023 10:50:30 -0500 Subject: [PATCH 08/11] Create controlnet directory. 
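To see how the pieces of the ControlNet patch above fit together before the small directory patch that follows, here is a minimal, self-contained sketch of the hand-off it implements: ControlNetApply copies each (embedding, options) pair and attaches a 'control' entry, and common_ksampler later collects those objects so the models can be loaded for sampling and cleaned up afterwards. The class and helper names below (ToyControlNet, apply_controlnet, gather_control_nets) are hypothetical stand-ins, not the repository's API. One detail worth flagging: the diff's negative-conditioning loop appears to test 'control' in p[1] rather than n[1], which looks like a leftover from the positive loop; the sketch checks each entry's own options.

# Minimal sketch of the conditioning/ControlNet hand-off added above; the
# "Toy" class and the free functions are illustrative, not ComfyUI's API.

class ToyControlNet:
    def __init__(self, model_id):
        self.model_id = model_id
        self.cond_hint_original = None

    def set_cond_hint(self, hint):
        # Returning self mirrors the control_net.copy().set_cond_hint(...) chaining.
        self.cond_hint_original = hint
        return self

    def copy(self):
        c = ToyControlNet(self.model_id)
        c.cond_hint_original = self.cond_hint_original
        return c

    def cleanup(self):
        self.cond_hint_original = None


def apply_controlnet(conditioning, control_net, image_hint):
    # Mirror ControlNetApply: copy each (embedding, options) pair instead of
    # deep-copying the whole list, and attach a per-entry control net.
    out = []
    for emb, opts in conditioning:
        new_opts = opts.copy()
        new_opts['control'] = control_net.copy().set_cond_hint(image_hint)
        out.append([emb, new_opts])
    return out


def gather_control_nets(positive, negative):
    # Mirror the common_ksampler change: collect every attached control net so
    # the caller can load the models before sampling and clean up afterwards.
    nets = []
    for cond in (positive, negative):
        for _, opts in cond:
            if 'control' in opts:
                nets.append(opts['control'])
    return nets


if __name__ == "__main__":
    positive = [["positive embedding", {}]]
    negative = [["negative embedding", {}]]
    positive = apply_controlnet(positive, ToyControlNet("canny"), "hint tensor")
    nets = gather_control_nets(positive, negative)
    # ... sampling with the attached control nets would happen here ...
    for net in nets:
        net.cleanup()
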
--- models/controlnet/put_controlnets_here | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 models/controlnet/put_controlnets_here diff --git a/models/controlnet/put_controlnets_here b/models/controlnet/put_controlnets_here new file mode 100644 index 00000000..e69de29b From 0787299610c017949d98b5d7d313e0937636e500 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 16 Feb 2023 18:19:26 +0000 Subject: [PATCH 09/11] Hide Triton warning on Windows Triton doesnt work on Windows so hide the error --- main.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/main.py b/main.py index c5f04aa9..f5aec442 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,10 @@ import heapq import traceback import asyncio +if os.name == "nt": + import logging + logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + try: import aiohttp from aiohttp import web From 6135a21ee813bd7bcb11fdfd9f363b469d5dabe1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Feb 2023 18:08:01 -0500 Subject: [PATCH 10/11] Add a way to control controlnet strength. --- comfy/sd.py | 7 ++++++- nodes.py | 10 +++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d37e5316..61a01dea 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -331,6 +331,7 @@ class ControlNet: self.control_model = control_model self.cond_hint_original = None self.cond_hint = None + self.strength = 1.0 def get_control(self, x_noisy, t, cond_txt): if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: @@ -340,10 +341,13 @@ class ControlNet: self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device) print("set cond_hint", self.cond_hint.shape) control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) + for x in control: + x *= self.strength return control - def set_cond_hint(self, cond_hint): + def set_cond_hint(self, cond_hint, strength=1.0): self.cond_hint_original = cond_hint + self.strength = strength return self def cleanup(self): @@ -354,6 +358,7 @@ class ControlNet: def copy(self): c = ControlNet(self.control_model) c.cond_hint_original = self.cond_hint_original + c.strength = self.strength return c def load_controlnet(ckpt_path): diff --git a/nodes.py b/nodes.py index 9aec9235..be3952ac 100644 --- a/nodes.py +++ b/nodes.py @@ -234,19 +234,23 @@ class ControlNetLoader: class ControlNetApply: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}} + return {"required": {"conditioning": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) + }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_controlnet" CATEGORY = "conditioning" - def apply_controlnet(self, conditioning, control_net, image): + def apply_controlnet(self, conditioning, control_net, image, strength): c = [] control_hint = image.movedim(-1,1) print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] - n[1]['control'] = control_net.copy().set_cond_hint(control_hint) + n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength) c.append(n) return (c, ) From 
638f924546615711e504b0c3d92b095ee92c6fab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Feb 2023 21:01:46 -0500 Subject: [PATCH 11/11] Sort the list of input images. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index be3952ac..547e6fbc 100644 --- a/nodes.py +++ b/nodes.py @@ -640,7 +640,7 @@ class LoadImage: @classmethod def INPUT_TYPES(s): return {"required": - {"image": (os.listdir(s.input_dir), )}, + {"image": (sorted(os.listdir(s.input_dir)), )}, } CATEGORY = "image"
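
The final patch wraps the directory listing in sorted() because os.listdir() makes no ordering guarantee, so without it the LoadImage dropdown could reorder between runs or differ across filesystems. A minimal illustration of the same pattern follows; the helper name and directory argument are hypothetical, not part of the node.

import os

def list_input_images(input_dir):
    # os.listdir() returns entries in arbitrary, filesystem-dependent order;
    # sorting keeps a UI dropdown stable across runs and operating systems.
    try:
        entries = os.listdir(input_dir)
    except FileNotFoundError:
        return []
    return sorted(entries)

# Example: list_input_images("input") -> ['a.png', 'b.png', ...]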