ComfyUI/nodes.py

905 lines
33 KiB
Python
Raw Normal View History

2023-01-03 06:53:32 +00:00
import torch
import os
import sys
import json
import hashlib
import copy
import traceback
2023-01-03 06:53:32 +00:00
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
2023-01-03 06:53:32 +00:00
import comfy.samplers
import comfy.sd
2023-02-16 15:38:08 +00:00
import comfy.utils
import model_management
2023-02-15 14:48:10 +00:00
import importlib
2023-01-03 06:53:32 +00:00
2023-02-16 15:38:08 +00:00
# File extensions offered by the checkpoint loaders (ckpt) and by the other
# weight-file loaders in this module (pt).
supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']

try:
    import safetensors.torch
except ImportError:
    # safetensors is an optional dependency: only a failed import should
    # disable it, any other error is a real problem and must surface.
    print("Could not import safetensors, safetensors support disabled.")
else:
    supported_ckpt_extensions += ['.safetensors']
    supported_pt_extensions += ['.safetensors']
def recursive_search(directory):
    """Return every file found below *directory*, as paths relative to it.

    Directory symlinks are followed. The order of the result is whatever
    ``os.walk`` yields; callers that need a stable order must sort.

    The relative path is computed with ``os.path.relpath`` rather than by
    string-replacing the directory prefix: a replace removes *every*
    occurrence of the prefix, which mangles paths in which the prefix text
    appears again further down the tree.
    """
    return [
        os.path.relpath(os.path.join(root, filename), directory)
        for root, _subdirs, filenames in os.walk(directory, followlinks=True)
        for filename in filenames
    ]
2023-01-03 06:53:32 +00:00
def filter_files_extensions(files, extensions):
    """Return the sorted subset of *files* whose extension is in *extensions*.

    The file extension is compared in lower case. *extensions* may be a
    collection of extensions or a single extension string. A single string is
    wrapped first, because ``ext in "some string"`` is a substring test and
    would, for example, accept files that have no extension at all.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    allowed = frozenset(extensions)
    return sorted(f for f in files if os.path.splitext(f)[-1].lower() in allowed)
def before_node_execution():
    """Hook run before a node executes; delegates the interrupt check to model_management."""
    model_management.throw_exception_if_processing_interrupted()
def interrupt_processing(value=True):
    """Forward *value* to model_management's interrupt flag for the current processing run."""
    model_management.interrupt_current_processing(value)
2023-01-03 06:53:32 +00:00
class CLIPTextEncode:
    """Node that encodes a text prompt with a CLIP object into a conditioning list."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        text_spec = ("STRING", {"multiline": True, "dynamic_prompt": True})
        clip_spec = ("CLIP", )
        return {"required": {"text": text_spec, "clip": clip_spec}}

    def encode(self, clip, text):
        """Return a one-entry conditioning list: ``[clip.encode(text), {}]``."""
        encoded = clip.encode(text)
        conditioning = [[encoded, {}]]
        return (conditioning, )
class ConditioningCombine:
    """Node that concatenates two conditioning lists."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "combine"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "conditioning_1": ("CONDITIONING", ),
            "conditioning_2": ("CONDITIONING", ),
        }
        return {"required": required}

    def combine(self, conditioning_1, conditioning_2):
        """Return a new list holding the entries of both inputs, first input first."""
        combined = conditioning_1 + conditioning_2
        return (combined, )
class ConditioningSetArea:
    """Node that tags every conditioning entry with an area, a strength and a sigma range."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "append"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        size_spec = {"default": 64, "min": 64, "max": 4096, "step": 64}
        offset_spec = {"default": 0, "min": 0, "max": 4096, "step": 64}
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "width": ("INT", dict(size_spec)),
                             "height": ("INT", dict(size_spec)),
                             "x": ("INT", dict(offset_spec)),
                             "y": ("INT", dict(offset_spec)),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                             }}

    def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
        """Return copies of the conditioning entries with the area options set.

        The options dict of each entry is shallow-copied, so the input
        entries are left untouched. The area is stored as
        ``(height, width, y, x)``, each value integer-divided by 8.
        """
        area_options = {
            'area': (height // 8, width // 8, y // 8, x // 8),
            'strength': strength,
            'min_sigma': min_sigma,
            'max_sigma': max_sigma,
        }
        tagged = []
        for entry in conditioning:
            options = entry[1].copy()
            options.update(area_options)
            tagged.append([entry[0], options])
        return (tagged, )
2023-01-03 06:53:32 +00:00
class VAEDecode:
    """Node that decodes a latent dict into an image through a VAE object."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"
    CATEGORY = "latent"

    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        required = {"samples": ("LATENT", ), "vae": ("VAE", )}
        return {"required": required}

    def decode(self, vae, samples):
        """Pass ``samples["samples"]`` to ``vae.decode`` and wrap the result in a tuple."""
        decoded = vae.decode(samples["samples"])
        return (decoded, )
2023-01-03 06:53:32 +00:00
class VAEDecodeTiled:
    """Node that decodes a latent dict through the VAE's ``decode_tiled`` method."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"
    CATEGORY = "_for_testing"

    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        required = {"samples": ("LATENT", ), "vae": ("VAE", )}
        return {"required": required}

    def decode(self, vae, samples):
        """Pass ``samples["samples"]`` to ``vae.decode_tiled`` and wrap the result in a tuple."""
        decoded = vae.decode_tiled(samples["samples"])
        return (decoded, )
2023-01-03 06:53:32 +00:00
class VAEEncode:
    """Node that encodes an image tensor into a latent dict through a VAE object."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"
    CATEGORY = "latent"

    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        required = {"pixels": ("IMAGE", ), "vae": ("VAE", )}
        return {"required": required}

    def encode(self, vae, pixels):
        """Crop dims 1 and 2 down to multiples of 64, keep the first 3 channels, then encode.

        Dims 1 and 2 are presumably height and width, with channels last -
        confirm against the IMAGE producers.
        """
        size_1 = (pixels.shape[1] // 64) * 64
        size_2 = (pixels.shape[2] // 64) * 64
        already_aligned = pixels.shape[1] == size_1 and pixels.shape[2] == size_2
        if not already_aligned:
            pixels = pixels[:, :size_1, :size_2, :]
        first_three_channels = pixels[:, :, :, :3]
        latent = vae.encode(first_three_channels)
        return ({"samples": latent}, )
2023-01-03 06:53:32 +00:00
class VAEEncodeForInpaint:
    """Node that encodes an image for inpainting and attaches a noise mask to the latent."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"
    CATEGORY = "latent/inpaint"

    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}

    def encode(self, vae, pixels, mask):
        """Blank the masked pixels to 0.5, encode, and return the latent with a grown mask.

        The mask is resized (bilinear) to the size of dims 1 and 2 of
        *pixels*. Both are then cropped so those dims are multiples of 64.
        The first three channels of every pixel where the rounded mask is 1
        are set to 0.5 before encoding.

        Returns a one-element tuple holding
        ``{"samples": <vae output>, "noise_mask": <grown, rounded mask>}``.
        """
        x = (pixels.shape[1] // 64) * 64
        y = (pixels.shape[2] // 64) * 64
        mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0]

        # The blanking below works in place. Without a private copy it would
        # overwrite the caller's image tensor, which other consumers of the
        # same image may still be using.
        pixels = pixels.clone()
        if pixels.shape[1] != x or pixels.shape[2] != y:
            pixels = pixels[:,:x,:y,:]
            mask = mask[:x,:y]

        #grow mask by a few pixels to keep things seamless in latent space
        kernel_tensor = torch.ones((1, 1, 6, 6))
        mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1)

        # m is 0 where the rounded mask is set, 1 elsewhere.
        m = (1.0 - mask.round())
        for i in range(3):
            pixels[:,:,:,i] -= 0.5
            pixels[:,:,:,i] *= m
            pixels[:,:,:,i] += 0.5
        t = vae.encode(pixels)

        # The convolution output is one element larger per spatial dim
        # (6x6 kernel, padding 3), hence the crop back to (x, y).
        return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
2023-01-03 06:53:32 +00:00
class CheckpointLoader:
    """Node that loads MODEL, CLIP and VAE from a checkpoint using an explicit config file."""

    # All paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    config_dir = os.path.join(models_dir, "configs")
    ckpt_dir = os.path.join(models_dir, "checkpoints")
    embedding_directory = os.path.join(models_dir, "embeddings")

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        # The extension is passed as a list: with a bare string the filter's
        # membership test becomes a substring test and lets files without
        # any extension through.
        config_names = filter_files_extensions(recursive_search(s.config_dir), ['.yaml'])
        ckpt_names = filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions)
        return {"required": { "config_name": (config_names, ),
                              "ckpt_name": (ckpt_names, )}}

    def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
        """Load the checkpoint *ckpt_name* using the config *config_name*.

        Both names are relative to ``ckpt_dir`` / ``config_dir``.
        *output_vae* and *output_clip* are forwarded to
        ``comfy.sd.load_checkpoint`` (they were previously ignored and always
        sent as True). The return value is whatever that function returns.
        """
        config_path = os.path.join(self.config_dir, config_name)
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=self.embedding_directory)
2023-01-03 06:53:32 +00:00
class CheckpointLoaderSimple:
    """Node that loads MODEL, CLIP and VAE from a checkpoint, letting comfy.sd guess the config."""

    # Paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    ckpt_dir = os.path.join(models_dir, "checkpoints")

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        ckpt_names = filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions)
        return {"required": { "ckpt_name": (ckpt_names, ),
                              }}

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
        """Load the checkpoint *ckpt_name* (relative to ``ckpt_dir``).

        *output_vae* and *output_clip* are forwarded to
        ``comfy.sd.load_checkpoint_guess_config`` (they were previously
        ignored and always sent as True). The embeddings directory is taken
        from ``CheckpointLoader``. The return value is whatever that function
        returns.
        """
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
        return comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=CheckpointLoader.embedding_directory)
class CLIPSetLastLayer:
    """Node that returns a clone of a CLIP object with ``clip_layer`` applied."""

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "set_last_layer"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        layer_spec = ("INT", {"default": -1, "min": -24, "max": -1, "step": 1})
        return {"required": { "clip": ("CLIP", ),
                              "stop_at_clip_layer": layer_spec,
                              }}

    def set_last_layer(self, clip, stop_at_clip_layer):
        """Clone *clip*, call ``clip_layer(stop_at_clip_layer)`` on the clone and return it."""
        cloned = clip.clone()
        cloned.clip_layer(stop_at_clip_layer)
        return (cloned,)
class LoraLoader:
    """Node that applies a LoRA file to a MODEL and a CLIP via comfy.sd."""

    # Paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    lora_dir = os.path.join(models_dir, "loras")

    RETURN_TYPES = ("MODEL", "CLIP")
    FUNCTION = "load_lora"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        lora_names = filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions)
        return {"required": { "model": ("MODEL",),
                              "clip": ("CLIP", ),
                              "lora_name": (lora_names, ),
                              "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              }}

    def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
        """Return the ``(model, clip)`` pair produced by ``comfy.sd.load_lora_for_models``."""
        lora_path = os.path.join(self.lora_dir, lora_name)
        patched = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
        # Unpack explicitly so a result that is not a 2-item iterable still fails here.
        patched_model, patched_clip = patched
        return (patched_model, patched_clip)
2023-01-03 06:53:32 +00:00
class VAELoader:
    """Node that builds a ``comfy.sd.VAE`` from a weights file in the vae directory."""

    # Paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    vae_dir = os.path.join(models_dir, "vae")

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        vae_names = filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions)
        return {"required": { "vae_name": (vae_names, )}}

    #TODO: scale factor?
    def load_vae(self, vae_name):
        """Construct the VAE from ``vae_dir/vae_name`` and return it in a one-element tuple."""
        full_path = os.path.join(self.vae_dir, vae_name)
        return (comfy.sd.VAE(ckpt_path=full_path),)
2023-02-16 15:38:08 +00:00
class ControlNetLoader:
    """Node that loads a control net from the controlnet directory via comfy.sd."""

    # Paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    controlnet_dir = os.path.join(models_dir, "controlnet")

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_controlnet"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        names = filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions)
        return {"required": { "control_net_name": (names, )}}

    def load_controlnet(self, control_net_name):
        """Load ``controlnet_dir/control_net_name`` and return it in a one-element tuple."""
        full_path = os.path.join(self.controlnet_dir, control_net_name)
        return (comfy.sd.load_controlnet(full_path),)
2023-02-23 04:22:03 +00:00
class DiffControlNetLoader:
    """Node that loads a control net via comfy.sd, passing a MODEL along with the path."""

    # Paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    controlnet_dir = os.path.join(models_dir, "controlnet")

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_controlnet"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        names = filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions)
        return {"required": { "model": ("MODEL",),
                              "control_net_name": (names, )}}

    def load_controlnet(self, model, control_net_name):
        """Load ``controlnet_dir/control_net_name`` with *model* as the second argument."""
        full_path = os.path.join(self.controlnet_dir, control_net_name)
        return (comfy.sd.load_controlnet(full_path, model),)
2023-02-16 15:38:08 +00:00
class ControlNetApply:
    """Node that attaches a control net, with an image hint, to every conditioning entry."""

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "apply_controlnet"
    CATEGORY = "conditioning"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING", ),
                             "control_net": ("CONTROL_NET", ),
                             "image": ("IMAGE", ),
                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
                             }}

    def apply_controlnet(self, conditioning, control_net, image, strength):
        """Return copies of the conditioning entries with a 'control' option set.

        Each entry gets its own copy of *control_net*, configured with the
        hint and *strength*. If the entry already carried a 'control' option,
        that one is chained in through ``set_previous_controlnet``. The
        options dicts of the input entries are not modified.
        """
        # Move the last dim of the image to position 1 (presumably
        # channels-last -> channels-first; confirm against IMAGE producers).
        control_hint = image.movedim(-1,1)
        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            c_net = control_net.copy().set_cond_hint(control_hint, strength)
            if 'control' in t[1]:
                c_net.set_previous_controlnet(t[1]['control'])
            n[1]['control'] = c_net
            c.append(n)
        return (c, )
class T2IAdapterLoader:
    """Node that loads a T2I adapter via comfy.sd; it is exposed as a CONTROL_NET output."""

    # Paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")

    RETURN_TYPES = ("CONTROL_NET",)
    FUNCTION = "load_t2i_adapter"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        names = filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions)
        return {"required": { "t2i_adapter_name": (names, )}}

    def load_t2i_adapter(self, t2i_adapter_name):
        """Load ``t2i_adapter_dir/t2i_adapter_name`` and return it in a one-element tuple."""
        full_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
        return (comfy.sd.load_t2i_adapter(full_path),)
2023-02-16 15:38:08 +00:00
class CLIPLoader:
    """Node that loads a standalone CLIP file via comfy.sd and applies ``clip_layer``."""

    # Paths are resolved relative to the directory containing this file.
    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    clip_dir = os.path.join(models_dir, "clip")

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_clip"
    CATEGORY = "loaders"

    @classmethod
    def INPUT_TYPES(s):
        clip_names = filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions)
        layer_spec = ("INT", {"default": -1, "min": -24, "max": -1, "step": 1})
        return {"required": { "clip_name": (clip_names, ),
                              "stop_at_clip_layer": layer_spec,
                              }}

    def load_clip(self, clip_name, stop_at_clip_layer):
        """Load ``clip_dir/clip_name`` and call ``clip_layer(stop_at_clip_layer)`` on it.

        The embeddings directory is taken from ``CheckpointLoader``.
        """
        full_path = os.path.join(self.clip_dir, clip_name)
        embeddings = CheckpointLoader.embedding_directory
        loaded = comfy.sd.load_clip(ckpt_path=full_path, embedding_directory=embeddings)
        loaded.clip_layer(stop_at_clip_layer)
        return (loaded,)
2023-01-03 06:53:32 +00:00
class EmptyLatentImage:
    """Node that creates an all-zero latent batch."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "latent"

    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "width": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
            "height": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
        }
        return {"required": required}

    def generate(self, width, height, batch_size=1):
        """Return ``{"samples": zeros}`` with shape ``[batch_size, 4, height // 8, width // 8]``."""
        shape = [batch_size, 4, height // 8, width // 8]
        return ({"samples": torch.zeros(shape)}, )
2023-01-03 06:53:32 +00:00
2023-02-16 15:38:08 +00:00
2023-01-03 06:53:32 +00:00
class LatentUpscale:
    """Node that resizes a latent through ``comfy.utils.common_upscale``."""

    upscale_methods = ["nearest-exact", "bilinear", "area"]
    crop_methods = ["disabled", "center"]

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        size_spec = {"default": 512, "min": 64, "max": 4096, "step": 64}
        return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
                              "width": ("INT", dict(size_spec)),
                              "height": ("INT", dict(size_spec)),
                              "crop": (s.crop_methods,)}}

    def upscale(self, samples, upscale_method, width, height, crop):
        """Return a shallow copy of *samples* whose "samples" entry is resized.

        *width* and *height* are integer-divided by 8 before being handed to
        the upscaler.
        """
        target_w = width // 8
        target_h = height // 8
        resized = comfy.utils.common_upscale(samples["samples"], target_w, target_h, upscale_method, crop)
        result = samples.copy()
        result["samples"] = resized
        return (result,)
2023-01-31 07:28:07 +00:00
class LatentRotate:
    """Node that rotates a latent by a multiple of 90 degrees."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "rotate"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        choices = ["none", "90 degrees", "180 degrees", "270 degrees"]
        return {"required": { "samples": ("LATENT",),
                              "rotation": (choices,),
                              }}

    def rotate(self, samples, rotation):
        """Return a shallow copy of *samples* with "samples" rotated via ``torch.rot90``.

        The number of quarter turns is chosen from the prefix of *rotation*
        ("90", "180", "270"); any other value means no rotation. The
        rotation is applied over dims ``[3, 2]``.
        """
        quarter_turns = 0
        for prefix, turns in (("90", 1), ("180", 2), ("270", 3)):
            if rotation.startswith(prefix):
                quarter_turns = turns
                break

        rotated = samples.copy()
        rotated["samples"] = torch.rot90(samples["samples"], k=quarter_turns, dims=[3, 2])
        return (rotated,)
2023-01-31 08:28:38 +00:00
class LatentFlip:
    """Node that mirrors a latent along one of its last two dims."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "flip"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        choices = ["x-axis: vertically", "y-axis: horizontally"]
        return {"required": { "samples": ("LATENT",),
                              "flip_method": (choices,),
                              }}

    def flip(self, samples, flip_method):
        """Return a shallow copy of *samples*, flipped according to *flip_method*.

        A method starting with "x" flips dim 2, one starting with "y" flips
        dim 3; anything else leaves the "samples" entry as it was.
        """
        if flip_method.startswith("x"):
            dim = 2
        elif flip_method.startswith("y"):
            dim = 3
        else:
            dim = None

        flipped = samples.copy()
        if dim is not None:
            flipped["samples"] = torch.flip(samples["samples"], dims=[dim])
        return (flipped,)
class LatentComposite:
    """Node that pastes one latent onto another at an offset, with optional feathering."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples_to": ("LATENT",),
                              "samples_from": ("LATENT",),
                              "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              "feather": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 8}),
                              }}

    def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
        """Paste ``samples_from`` into a copy of ``samples_to`` at offset (x, y).

        *x*, *y* and *feather* are integer-divided by 8 before use. The
        source is cropped to the part that fits inside the destination. With
        a non-zero feather, the source edges that do not touch a destination
        border are blended linearly over ``feather`` rows/columns.

        An offset at or beyond the destination size leaves the destination
        unchanged. Previously such an offset produced mismatched slice shapes
        and raised ``RuntimeError``.

        *composite_method* is accepted for interface compatibility but is not
        used. Returns a one-element tuple with a shallow copy of
        ``samples_to`` whose "samples" entry is the composited tensor.
        """
        x = x // 8
        y = y // 8
        feather = feather // 8

        samples_out = samples_to.copy()
        dest = samples_to["samples"]
        s = dest.clone()
        samples_out["samples"] = s

        # Nothing of the source lands inside the destination.
        if y >= dest.shape[2] or x >= dest.shape[3]:
            return (samples_out,)

        # Keep only the part of the source that fits at this offset.
        src = samples_from["samples"][:, :, :dest.shape[2] - y, :dest.shape[3] - x]
        bottom = y + src.shape[2]
        right = x + src.shape[3]

        if feather == 0:
            s[:, :, y:bottom, x:right] = src
            return (samples_out,)

        mask = torch.ones_like(src)
        for t in range(feather):
            weight = (1.0 / feather) * (t + 1)
            # Only edges that lie strictly inside the destination are faded.
            if y != 0:
                mask[:, :, t:1 + t, :] *= weight
            if bottom < dest.shape[2]:
                mask[:, :, mask.shape[2] - 1 - t: mask.shape[2] - t, :] *= weight
            if x != 0:
                mask[:, :, :, t:1 + t] *= weight
            if right < dest.shape[3]:
                mask[:, :, :, mask.shape[3] - 1 - t: mask.shape[3] - t] *= weight
        rev_mask = torch.ones_like(mask) - mask
        s[:, :, y:bottom, x:right] = src * mask + s[:, :, y:bottom, x:right] * rev_mask
        return (samples_out,)
2023-02-04 20:21:46 +00:00
class LatentCrop:
    """Node that cuts a rectangular window out of a latent."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "crop"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        size_spec = {"default": 512, "min": 64, "max": 4096, "step": 64}
        offset_spec = {"default": 0, "min": 0, "max": 4096, "step": 8}
        return {"required": { "samples": ("LATENT",),
                              "width": ("INT", dict(size_spec)),
                              "height": ("INT", dict(size_spec)),
                              "x": ("INT", dict(offset_spec)),
                              "y": ("INT", dict(offset_spec)),
                              }}

    def crop(self, samples, width, height, x, y):
        """Return a shallow copy of *samples* whose "samples" entry is the cropped window.

        All four numbers are integer-divided by 8 before use. The offsets are
        capped so that at least 8 latent cells remain (the original comment
        calls this a minimum size of 64), and a window running past the edge
        is clamped to the edge.
        """
        def clamp_span(start, stop, limit):
            # Window fits: nothing to do.
            if stop <= limit:
                return (start, stop)
            # Window overshoots: end at the edge and pull the start back by
            # the overshoot modulo 8.
            overshoot = (stop - limit) % 8
            return (start - overshoot, limit)

        latent = samples['samples']
        rows, cols = latent.shape[2], latent.shape[3]

        # Enforce the minimum remaining size.
        left = min(x // 8, cols - 8)
        top = min(y // 8, rows - 8)

        left, right = clamp_span(left, left + width // 8, cols)
        top, bottom = clamp_span(top, top + height // 8, rows)

        cropped = samples.copy()
        cropped['samples'] = latent[:, :, top:bottom, left:right]
        return (cropped,)
class SetLatentNoiseMask:
    """Node that attaches a mask to a latent dict under the "noise_mask" key."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "set_mask"
    CATEGORY = "latent/inpaint"

    @classmethod
    def INPUT_TYPES(s):
        required = {"samples": ("LATENT",), "mask": ("MASK",)}
        return {"required": required}

    def set_mask(self, samples, mask):
        """Return a shallow copy of *samples* with "noise_mask" set to *mask*."""
        masked = samples.copy()
        masked["noise_mask"] = mask
        return (masked,)
def _prepare_conditioning(conditioning, batch_size, device):
    """Return (entries moved to *device*, control nets found in the entries).

    Each entry is a list whose first item is a tensor and whose second item
    is an options dict. The tensor is repeated *batch_size* times along dim 0
    when its first dim is smaller than the batch.
    """
    prepared = []
    control_nets = []
    for entry in conditioning:
        t = entry[0]
        if t.shape[0] < batch_size:
            t = torch.cat([t] * batch_size)
        t = t.to(device)
        if 'control' in entry[1]:
            control_nets.append(entry[1]['control'])
        prepared.append([t] + entry[1:])
    return prepared, control_nets


def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
    """Run a comfy.samplers.KSampler over *latent* and return the sampled latent.

    Noise is generated on the CPU from *seed* (or is all zeros when
    *disable_noise* is set) and then moved to *device*. If the latent carries
    a "noise_mask", it is resized to the latent size, rounded, and repeated
    over channels and batch before being passed on as ``denoise_mask``.

    Control nets referenced by the positive and negative conditioning are
    loaded before sampling and cleaned up afterwards, even if sampling fails.

    Returns a one-element tuple with a shallow copy of *latent* whose
    "samples" entry is the result, moved back to the CPU.

    Raises:
        ValueError: if *sampler_name* is not one of
            ``comfy.samplers.KSampler.SAMPLERS``. Previously this fell
            through and failed with an unbound-variable error.
    """
    if sampler_name not in comfy.samplers.KSampler.SAMPLERS:
        raise ValueError(f"Unsupported sampler: {sampler_name}")

    latent_image = latent["samples"]
    noise_mask = None

    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")

    if "noise_mask" in latent:
        noise_mask = latent['noise_mask']
        noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
        noise_mask = noise_mask.round()
        noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
        noise_mask = torch.cat([noise_mask] * noise.shape[0])
        noise_mask = noise_mask.to(device)

    if device != "cpu":
        model_management.load_model_gpu(model)
        real_model = model.model
    else:
        #TODO: cpu support
        real_model = model.patch_model()
    noise = noise.to(device)
    latent_image = latent_image.to(device)

    # Control nets are collected from each entry's *own* options. The
    # negative loop used to look at the last positive entry instead.
    positive_copy, positive_nets = _prepare_conditioning(positive, noise.shape[0], device)
    negative_copy, negative_nets = _prepare_conditioning(negative, noise.shape[0], device)
    control_nets = positive_nets + negative_nets

    control_net_models = []
    for x in control_nets:
        control_net_models += x.get_control_models()
    model_management.load_controlnet_gpu(control_net_models)

    try:
        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
        samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
        samples = samples.cpu()
    finally:
        for c in control_nets:
            c.cleanup()

    out = latent.copy()
    out["samples"] = samples
    return (out, )
2023-01-03 06:53:32 +00:00
class KSampler:
    """Sampling node; a thin front end over ``common_ksampler``."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
    CATEGORY = "sampling"

    def __init__(self, device="cuda"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "latent_image": ("LATENT", ),
            "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        }
        return {"required": required}

    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
        """Sample on ``self.device`` and return the result of ``common_ksampler`` unchanged."""
        result = common_ksampler(
            self.device, model, seed, steps, cfg, sampler_name, scheduler,
            positive, negative, latent_image,
            denoise=denoise,
        )
        return result
2023-01-03 06:53:32 +00:00
class KSamplerAdvanced:
    """Sampling node exposing noise, step-range and leftover-noise controls."""

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"
    CATEGORY = "sampling"

    def __init__(self, device="cuda"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "add_noise": (["enable", "disable"], ),
            "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
            "positive": ("CONDITIONING", ),
            "negative": ("CONDITIONING", ),
            "latent_image": ("LATENT", ),
            "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
            "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
            "return_with_leftover_noise": (["disable", "enable"], ),
        }
        return {"required": required}

    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
        """Translate the string switches to booleans and delegate to ``common_ksampler``.

        Noise is disabled only when *add_noise* is exactly "disable". A full
        denoise is forced unless *return_with_leftover_noise* is exactly
        "enable".
        """
        disable_noise = add_noise == "disable"
        force_full_denoise = return_with_leftover_noise != "enable"
        return common_ksampler(
            self.device, model, noise_seed, steps, cfg, sampler_name, scheduler,
            positive, negative, latent_image,
            denoise=denoise,
            disable_noise=disable_noise,
            start_step=start_at_step,
            last_step=end_at_step,
            force_full_denoise=force_full_denoise,
        )
2023-01-03 06:53:32 +00:00
class SaveImage:
    """Output node that writes images as numbered PNG files into the output directory."""

    RETURN_TYPES = ()
    FUNCTION = "save_images"
    OUTPUT_NODE = True
    CATEGORY = "image"

    def __init__(self):
        # The output directory sits next to this file.
        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ),
                     "filename_prefix": ("STRING", {"default": "ComfyUI"})},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    def _next_counter(self, filename_prefix):
        """Return the next free file number for *filename_prefix*, starting at 1.

        Existing files named ``<prefix>_<digits>_...`` are scanned and the
        highest number plus one is returned. Files whose number part cannot
        be parsed count as 0. A missing output directory is created.
        """
        marker = filename_prefix + "_"
        try:
            existing = os.listdir(self.output_dir)
        except FileNotFoundError:
            os.makedirs(self.output_dir, exist_ok=True)
            return 1

        highest = 0
        for filename in existing:
            if not filename.startswith(marker):
                continue
            try:
                digits = int(filename[len(marker):].split('_')[0])
            except ValueError:
                digits = 0
            highest = max(highest, digits)
        return highest + 1

    def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        """Save every image of the batch and return the written file names for the UI.

        Files are named ``<prefix>_<counter:05>_.png``. *prompt* and every
        key of *extra_pnginfo* are stored as JSON text chunks in the PNG
        metadata when given.
        """
        counter = self._next_counter(filename_prefix)

        paths = list()
        for image in images:
            i = 255. * image.cpu().numpy()
            # Clip before the cast: values outside 0..255 would otherwise
            # wrap around when converted to uint8.
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            metadata = PngInfo()
            if prompt is not None:
                metadata.add_text("prompt", json.dumps(prompt))
            if extra_pnginfo is not None:
                for x in extra_pnginfo:
                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
            file = f"{filename_prefix}_{counter:05}_.png"
            img.save(os.path.join(self.output_dir, file), pnginfo=metadata, optimize=True)
            paths.append(file)
            counter += 1
        return { "ui": { "images": paths } }
2023-01-03 06:53:32 +00:00
class LoadImage:
    """Node that loads a file from the input directory as an RGB image tensor."""

    # The input directory sits next to this file.
    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")

    CATEGORY = "image"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load_image"

    @classmethod
    def INPUT_TYPES(s):
        available = sorted(os.listdir(s.input_dir))
        return {"required":
                    {"image": (available, )},
                }

    def load_image(self, image):
        """Open ``input_dir/image``, convert it to RGB and return a float32 tensor.

        Pixel values are divided by 255.0 and a leading dim of size 1 is
        added in front of the array.
        """
        source = Image.open(os.path.join(self.input_dir, image))
        rgb = source.convert("RGB")
        scaled = np.array(rgb).astype(np.float32) / 255.0
        tensor = torch.from_numpy(scaled)[None,]
        return (tensor,)

    @classmethod
    def IS_CHANGED(s, image):
        """Return the SHA-256 hex digest of the file's bytes."""
        with open(os.path.join(s.input_dir, image), 'rb') as f:
            hasher = hashlib.sha256(f.read())
        return hasher.hexdigest()
class LoadImageMask:
    """Node that loads one channel of a file from the input directory as a mask tensor."""

    # The input directory sits next to this file.
    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")

    CATEGORY = "image"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "load_image"

    @classmethod
    def INPUT_TYPES(s):
        available = sorted(os.listdir(s.input_dir))
        return {"required":
                    {"image": (available, ),
                     "channel": (["alpha", "red", "green", "blue"], ),}
                }

    def load_image(self, image, channel):
        """Return the requested channel of ``input_dir/image`` as a float32 mask.

        The band is picked by the upper-cased first letter of *channel*.
        Values are divided by 255.0; the alpha band is additionally inverted
        (``1 - value``). When the image has no such band, a 64x64 all-zero
        mask is returned instead.
        """
        source = Image.open(os.path.join(self.input_dir, image))
        band = channel[0].upper()

        if band not in source.getbands():
            return (torch.zeros((64,64), dtype=torch.float32, device="cpu"),)

        values = np.array(source.getchannel(band)).astype(np.float32) / 255.0
        mask = torch.from_numpy(values)
        if band == 'A':
            mask = 1. - mask
        return (mask,)

    @classmethod
    def IS_CHANGED(s, image, channel):
        """Return the SHA-256 hex digest of the file's bytes (*channel* is not used)."""
        with open(os.path.join(s.input_dir, image), 'rb') as f:
            hasher = hashlib.sha256(f.read())
        return hasher.hexdigest()
class ImageScale:
    """Node that resizes an image through ``comfy.utils.common_upscale``."""

    upscale_methods = ["nearest-exact", "bilinear", "area"]
    crop_methods = ["disabled", "center"]

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(s):
        size_spec = {"default": 512, "min": 1, "max": 4096, "step": 1}
        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
                              "width": ("INT", dict(size_spec)),
                              "height": ("INT", dict(size_spec)),
                              "crop": (s.crop_methods,)}}

    def upscale(self, image, upscale_method, width, height, crop):
        """Resize *image* to *width* x *height* and return it in a one-element tuple.

        The last dim is moved to position 1 for the upscaler and moved back
        afterwards.
        """
        moved = image.movedim(-1,1)
        resized = comfy.utils.common_upscale(moved, width, height, upscale_method, crop)
        restored = resized.movedim(1,-1)
        return (restored,)
2023-01-03 06:53:32 +00:00
class ImageInvert:
    """Node that inverts an image by subtracting it from 1.0."""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "invert"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(s):
        required = {"image": ("IMAGE",)}
        return {"required": required}

    def invert(self, image):
        """Return ``1.0 - image`` in a one-element tuple; the input is not modified."""
        inverted = 1.0 - image
        return (inverted,)
2023-01-03 06:53:32 +00:00
# Registry of the built-in nodes: node name -> implementing class.
# load_custom_nodes() merges the mappings of custom node modules into it.
NODE_CLASS_MAPPINGS = dict(
    KSampler=KSampler,
    CheckpointLoader=CheckpointLoader,
    CheckpointLoaderSimple=CheckpointLoaderSimple,
    CLIPTextEncode=CLIPTextEncode,
    CLIPSetLastLayer=CLIPSetLastLayer,
    VAEDecode=VAEDecode,
    VAEEncode=VAEEncode,
    VAEEncodeForInpaint=VAEEncodeForInpaint,
    VAELoader=VAELoader,
    EmptyLatentImage=EmptyLatentImage,
    LatentUpscale=LatentUpscale,
    SaveImage=SaveImage,
    LoadImage=LoadImage,
    LoadImageMask=LoadImageMask,
    ImageScale=ImageScale,
    ImageInvert=ImageInvert,
    ConditioningCombine=ConditioningCombine,
    ConditioningSetArea=ConditioningSetArea,
    KSamplerAdvanced=KSamplerAdvanced,
    SetLatentNoiseMask=SetLatentNoiseMask,
    LatentComposite=LatentComposite,
    LatentRotate=LatentRotate,
    LatentFlip=LatentFlip,
    LatentCrop=LatentCrop,
    LoraLoader=LoraLoader,
    CLIPLoader=CLIPLoader,
    ControlNetApply=ControlNetApply,
    ControlNetLoader=ControlNetLoader,
    DiffControlNetLoader=DiffControlNetLoader,
    T2IAdapterLoader=T2IAdapterLoader,
    VAEDecodeTiled=VAEDecodeTiled,
)
2023-02-14 15:20:30 +00:00
# Directory scanned for custom node modules; it sits next to this file.
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")


def load_custom_nodes():
    """Import every module in CUSTOM_NODE_PATH and merge its NODE_CLASS_MAPPINGS.

    A candidate is either a ``.py`` file or a directory (loaded through its
    ``__init__.py``). Other files and ``__pycache__`` are skipped. Each
    module is registered in ``sys.modules`` under its directory entry name
    and then executed.

    Import failures are printed with their traceback and do not stop the
    remaining modules from loading. A missing custom node directory simply
    means there is nothing to load; it used to crash this module's import.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly.
    import importlib.util

    if not os.path.isdir(CUSTOM_NODE_PATH):
        return

    for possible_module in os.listdir(CUSTOM_NODE_PATH):
        if possible_module == "__pycache__":
            continue

        module_path = os.path.join(CUSTOM_NODE_PATH, possible_module)
        if os.path.isfile(module_path):
            if os.path.splitext(module_path)[1] != ".py":
                continue
            spec_path = module_path
        else:
            spec_path = os.path.join(module_path, "__init__.py")

        module_name = possible_module
        try:
            module_spec = importlib.util.spec_from_file_location(module_name, spec_path)
            module = importlib.util.module_from_spec(module_spec)
            sys.modules[module_name] = module
            module_spec.loader.exec_module(module)
            if getattr(module, "NODE_CLASS_MAPPINGS", None) is not None:
                NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
            else:
                print(f"Skip {possible_module} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
        except Exception as e:
            # Best effort by design: one broken custom node must not take the others down.
            print(traceback.format_exc())
            print(f"Cannot import {possible_module} module for custom nodes:", e)


load_custom_nodes()