Merge remote-tracking branch 'origin/master' into tiled-progress

This commit is contained in:
pythongosssss 2023-05-03 17:33:42 +01:00
commit fdf57325f4
32 changed files with 900 additions and 356 deletions

View File

@ -1,65 +0,0 @@
import pygit2
from datetime import datetime
import sys
def pull(repo, remote_name='origin', branch='master'):
for remote in repo.remotes:
if remote.name == remote_name:
remote.fetch()
remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
merge_result, _ = repo.merge_analysis(remote_master_id)
# Up to date, do nothing
if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
return
# We can just fastforward
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
repo.checkout_tree(repo.get(remote_master_id))
try:
master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
master_ref.set_target(remote_master_id)
except KeyError:
repo.create_branch(branch, repo.get(remote_master_id))
repo.head.set_target(remote_master_id)
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
repo.merge(remote_master_id)
if repo.index.conflicts is not None:
for conflict in repo.index.conflicts:
print('Conflicts found in:', conflict[0].path)
raise AssertionError('Conflicts, ahhhhh!!')
user = repo.default_signature
tree = repo.index.write_tree()
commit = repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:
raise AssertionError('Unknown merge analysis result')
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
print("stashing current changes")
repo.stash(ident)
except KeyError:
print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
repo.branches.local.create(backup_branch_name, repo.head.peel())
print("checking out master branch")
branch = repo.lookup_branch('master')
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
print("pulling latest changes")
pull(repo)
print("Done!")

View File

@ -1,2 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
pause

View File

@ -1,3 +1,3 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\ ..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2 ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
pause pause

View File

@ -1,27 +0,0 @@
HOW TO RUN:
if you have a NVIDIA gpu:
run_nvidia_gpu.bat
To run it in slow CPU mode:
run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
To update ComfyUI with the python dependencies:
update\update_comfyui_and_python_dependencies.bat

View File

@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
pause

View File

@ -17,7 +17,7 @@ jobs:
- shell: bash - shell: bash
run: | run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/* python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic echo installed basic
ls -lah temp_wheel_dir ls -lah temp_wheel_dir

View File

@ -19,21 +19,21 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: '3.10.9' python-version: '3.11.3'
- shell: bash - shell: bash
run: | run: |
cd .. cd ..
cp -r ComfyUI ComfyUI_copy cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded unzip python_embeded.zip -d python_embeded
cd python_embeded cd python_embeded
echo 'import site' >> ./python310._pth echo 'import site' >> ./python311._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py ./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/* ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python310._pth sed -i '1i../ComfyUI' ./python311._pth
cd .. cd ..
@ -46,6 +46,8 @@ jobs:
mkdir update mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./ cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
cd .. cd ..

View File

@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend.
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
### [Installing ComfyUI](#installing)
## Features ## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x and SD2.x - Fully supports SD1.x and SD2.x
@ -17,6 +19,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion - Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
- Loading full workflows (with seeds) from generated PNG files. - Loading full workflows (with seeds) from generated PNG files.
- Saving/Loading workflows as Json files. - Saving/Loading workflows as Json files.
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.

View File

@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")

View File

@ -712,7 +712,7 @@ class UniPC:
def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform', def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05, corrector=False, atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
): ):
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start t_T = self.noise_schedule.T if t_start is None else t_start
@ -723,7 +723,7 @@ class UniPC:
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
# with torch.no_grad(): # with torch.no_grad():
for step_index in trange(steps): for step_index in trange(steps, disable=disable_pbar):
if self.noise_mask is not None: if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0: if step_index == 0:
@ -766,6 +766,8 @@ class UniPC:
if model_x is None: if model_x is None:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x model_prev_list[-1] = model_x
if callback is not None:
callback(step_index, model_prev_list[-1], x)
else: else:
raise NotImplementedError() raise NotImplementedError()
if denoise_to_zero: if denoise_to_zero:
@ -833,7 +835,7 @@ def expand_dims(v, dims):
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'): def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
to_zero = False to_zero = False
if sigmas[-1] == 0: if sigmas[-1] == 0:
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
order = min(3, len(timesteps) - 1) order = min(3, len(timesteps) - 1)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
if not to_zero: if not to_zero:
x /= ns.marginal_alpha(timesteps[-1]) x /= ns.marginal_alpha(timesteps[-1])
return x return x

View File

@ -81,6 +81,7 @@ class DDIMSampler(object):
extra_args=None, extra_args=None,
to_zero=True, to_zero=True,
end_step=None, end_step=None,
disable_pbar=False,
**kwargs **kwargs
): ):
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose) self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
@ -103,7 +104,8 @@ class DDIMSampler(object):
denoise_function=denoise_function, denoise_function=denoise_function,
extra_args=extra_args, extra_args=extra_args,
to_zero=to_zero, to_zero=to_zero,
end_step=end_step end_step=end_step,
disable_pbar=disable_pbar
) )
return samples, intermediates return samples, intermediates
@ -185,7 +187,7 @@ class DDIMSampler(object):
mask=None, x0=None, img_callback=None, log_every_t=100, mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None): ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
@ -204,7 +206,7 @@ class DDIMSampler(object):
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps") # print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step) iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1

View File

@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input. support it as an extra input.
""" """
def forward(self, x, emb, context=None, transformer_options={}): def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self: for layer in self:
if isinstance(layer, TimestepBlock): if isinstance(layer, TimestepBlock):
x = layer(x, emb) x = layer(x, emb)
elif isinstance(layer, SpatialTransformer): elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options) x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else: else:
x = layer(x) x = layer(x)
return x return x
@ -105,14 +107,20 @@ class Upsample(nn.Module):
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x): def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
if self.dims == 3: if self.dims == 3:
x = F.interpolate( shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" if output_shape is not None:
) shape[1] = output_shape[3]
shape[2] = output_shape[4]
else: else:
x = F.interpolate(x, scale_factor=2, mode="nearest") shape = [x.shape[2] * 2, x.shape[3] * 2]
if output_shape is not None:
shape[0] = output_shape[2]
shape[1] = output_shape[3]
x = F.interpolate(x, size=shape, mode="nearest")
if self.use_conv: if self.use_conv:
x = self.conv(x) x = self.conv(x)
return x return x
@ -813,9 +821,14 @@ class UNetModel(nn.Module):
ctrl = control['output'].pop() ctrl = control['output'].pop()
if ctrl is not None: if ctrl is not None:
hsp += ctrl hsp += ctrl
h = th.cat([h, hsp], dim=1) h = th.cat([h, hsp], dim=1)
del hsp del hsp
h = module(h, emb, context, transformer_options) if len(hs) > 0:
output_shape = hs[-1].shape
else:
output_shape = None
h = module(h, emb, context, transformer_options, output_shape)
h = h.type(x.dtype) h = h.type(x.dtype)
if self.predict_codebook_ids: if self.predict_codebook_ids:
return self.id_predictor(h) return self.id_predictor(h)

View File

@ -20,15 +20,30 @@ total_vram_available_mb = -1
accelerate_enabled = False accelerate_enabled = False
xpu_available = False xpu_available = False
directml_enabled = False
if args.directml is not None:
import torch_directml
directml_enabled = True
device_index = args.directml
if device_index < 0:
directml_device = torch_directml.device()
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
try: try:
import torch import torch
try: if directml_enabled:
import intel_extension_for_pytorch as ipex total_vram = 4097 #TODO
if torch.xpu.is_available(): else:
xpu_available = True try:
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) import intel_extension_for_pytorch as ipex
except: if torch.xpu.is_available():
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) xpu_available = True
total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
except:
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
if not args.normalvram and not args.cpu: if not args.normalvram and not args.cpu:
if total_vram <= 4096: if total_vram <= 4096:
@ -217,6 +232,10 @@ def unload_if_low_vram(model):
def get_torch_device(): def get_torch_device():
global xpu_available global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS: if vram_state == VRAMState.MPS:
return torch.device("mps") return torch.device("mps")
if vram_state == VRAMState.CPU: if vram_state == VRAMState.CPU:
@ -234,8 +253,14 @@ def get_autocast_device(dev):
def xformers_enabled(): def xformers_enabled():
global xpu_available
global directml_enabled
if vram_state == VRAMState.CPU: if vram_state == VRAMState.CPU:
return False return False
if xpu_available:
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE return XFORMERS_IS_AVAILABLE
@ -251,6 +276,7 @@ def pytorch_attention_enabled():
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
global xpu_available global xpu_available
global directml_enabled
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
@ -258,7 +284,10 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
if xpu_available: if directml_enabled:
mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total
elif xpu_available:
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
@ -293,9 +322,14 @@ def mps_mode():
def should_use_fp16(): def should_use_fp16():
global xpu_available global xpu_available
global directml_enabled
if FORCE_FP32: if FORCE_FP32:
return False return False
if directml_enabled:
return False
if cpu_mode() or mps_mode() or xpu_available: if cpu_mode() or mps_mode() or xpu_available:
return False #TODO ? return False #TODO ?

83
comfy/sample.py Normal file
View File

@ -0,0 +1,83 @@
import torch
import comfy.model_management
import comfy.samplers
import math
def prepare_noise(latent_image, seed, skip=0):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
for _ in range(skip):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
return noise
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
if noise_mask.shape[0] < shape[0]:
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
noise_mask = noise_mask.to(device)
return noise_mask
def broadcast_cond(cond, batch, device):
"""broadcasts conditioning to the batch size"""
copy = []
for p in cond:
t = p[0]
if t.shape[0] < batch:
t = torch.cat([t] * batch)
t = t.to(device)
copy += [[t] + p[1:]]
return copy
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c[1]:
models += [c[1][model_type]]
return models
def load_additional_models(positive, negative):
"""loads additional models in positive and negative conditioning"""
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_nets + gligen
comfy.model_management.load_controlnet_gpu(models)
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
device = comfy.model_management.get_torch_device()
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None
comfy.model_management.load_model_gpu(model)
real_model = model.model
noise = noise.to(device)
latent_image = latent_image.to(device)
positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device)
models = load_additional_models(positive, negative)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
samples = samples.cpu()
cleanup_additional_models(models)
return samples

View File

@ -23,21 +23,36 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
adm_cond = cond[1]['adm_encoded'] adm_cond = cond[1]['adm_encoded']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength if 'mask' in cond[1]:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in cond[1]:
mask_strength = cond[1]["mask_strength"]
mask = cond[1]['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in cond[1]:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {} conditionning = {}
conditionning['c_crossattn'] = cond[0] conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0: if cond_concat_in is not None and len(cond_concat_in) > 0:
@ -301,6 +316,71 @@ def blank_inpaint_image_like(latent_image):
blank_image[:,3] *= 0.1380 blank_image[:,3] *= 0.1380
return blank_image return blank_image
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
b = masks.shape[0]
bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
for i in range(b):
mask = masks[i]
if mask.numel() == 0:
continue
if torch.max(mask != 0) == False:
is_empty[i] = True
continue
y, x = torch.where(mask)
bounding_boxes[i, 0] = torch.min(x)
bounding_boxes[i, 1] = torch.min(y)
bounding_boxes[i, 2] = torch.max(x)
bounding_boxes[i, 3] = torch.max(y)
return bounding_boxes, is_empty
def resolve_cond_masks(conditions, h, w, device):
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range(len(conditions)):
c = conditions[i]
if 'mask' in c[1]:
mask = c[1]['mask']
mask = mask.to(device=device)
modified = c[1].copy()
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if mask.shape[2] != h or mask.shape[3] != w:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
if modified.get("set_area_to_bounds", False):
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
boxes, is_empty = get_mask_aabb(bounds)
if is_empty[0]:
# Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
modified['area'] = (8, 8, 0, 0)
else:
box = boxes[0]
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
# Make sure the height and width are divisible by 8
if X % 8 != 0:
newx = X // 8 * 8
W = W + (X - newx)
X = newx
if Y % 8 != 0:
newy = Y // 8 * 8
H = H + (Y - newy)
Y = newy
if H % 8 != 0:
H = H + (8 - (H % 8))
if W % 8 != 0:
W = W + (8 - (W % 8))
area = (int(H), int(W), int(Y), int(X))
modified['area'] = area
modified['mask'] = mask
conditions[i] = [c[0], modified]
def create_cond_with_same_area_if_none(conds, c): def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]: if 'area' not in c[1]:
return return
@ -429,7 +509,7 @@ class KSampler:
self.denoise = denoise self.denoise = denoise
self.model_options = model_options self.model_options = model_options
def _calculate_sigmas(self, steps): def calculate_sigmas(self, steps):
sigmas = None sigmas = None
discard_penultimate_sigma = False discard_penultimate_sigma = False
@ -438,13 +518,13 @@ class KSampler:
discard_penultimate_sigma = True discard_penultimate_sigma = True
if self.scheduler == "karras": if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal": elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps).to(self.device) sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple": elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device) sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform": elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device) sigmas = ddim_scheduler(self.model_wrap, steps)
else: else:
print("error invalid scheduler", self.scheduler) print("error invalid scheduler", self.scheduler)
@ -455,15 +535,15 @@ class KSampler:
def set_steps(self, steps, denoise=None): def set_steps(self, steps, denoise=None):
self.steps = steps self.steps = steps
if denoise is None or denoise > 0.9999: if denoise is None or denoise > 0.9999:
self.sigmas = self._calculate_sigmas(steps) self.sigmas = self.calculate_sigmas(steps).to(self.device)
else: else:
new_steps = int(steps/denoise) new_steps = int(steps/denoise)
sigmas = self._calculate_sigmas(new_steps) sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):] self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): if sigmas is None:
sigmas = self.sigmas sigmas = self.sigmas
sigma_min = self.sigma_min sigma_min = self.sigma_min
if last_step is not None and last_step < (len(sigmas) - 1): if last_step is not None and last_step < (len(sigmas) - 1):
@ -483,6 +563,10 @@ class KSampler:
positive = positive[:] positive = positive[:]
negative = negative[:] negative = negative[:]
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for c in positive:
create_cond_with_same_area_if_none(negative, c) create_cond_with_same_area_if_none(negative, c)
@ -526,9 +610,9 @@ class KSampler:
with precision_scope(model_management.get_autocast_device(self.device)): with precision_scope(model_management.get_autocast_device(self.device)):
if self.sampler == "uni_pc": if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask) samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2": elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim": elif self.sampler == "ddim":
timesteps = [] timesteps = []
for s in range(sigmas.shape[0]): for s in range(sigmas.shape[0]):
@ -536,6 +620,11 @@ class KSampler:
noise_mask = None noise_mask = None
if denoise_mask is not None: if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None)
sampler = DDIMSampler(self.model, device=self.device) sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise) z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
@ -549,11 +638,13 @@ class KSampler:
eta=0.0, eta=0.0,
x_T=z_enc, x_T=z_enc,
x0=latent_image, x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function, denoise_function=sampling_function,
extra_args=extra_args, extra_args=extra_args,
mask=noise_mask, mask=noise_mask,
to_zero=sigmas[-1]==0, to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1) end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
else: else:
extra_args["denoise_mask"] = denoise_mask extra_args["denoise_mask"] = denoise_mask
@ -562,13 +653,17 @@ class KSampler:
noise = noise * sigmas[0] noise = noise * sigmas[0]
k_callback = None
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"])
if latent_image is not None: if latent_image is not None:
noise += latent_image noise += latent_image
if self.sampler == "dpm_fast": if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive": elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else: else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args) samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
return samples.to(torch.float32) return samples.to(torch.float32)

View File

@ -112,6 +112,8 @@ def load_lora(path, to_load):
loaded_keys.add(A_name) loaded_keys.add(A_name)
loaded_keys.add(B_name) loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x) hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x) hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x) hada_w2_a_name = "{}.hada_w2_a".format(x)
@ -133,6 +135,54 @@ def load_lora(path, to_load):
loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name) loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
for x in lora.keys(): for x in lora.keys():
if x not in loaded_keys: if x not in loaded_keys:
print("lora key not loaded", x) print("lora key not loaded", x)
@ -316,6 +366,33 @@ class ModelPatcher:
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float(), w2_b.float())
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
else: #loha else: #loha
w1a = v[0] w1a = v[0]
w1b = v[1] w1b = v[1]

View File

@ -94,3 +94,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div output[b:b+1] = out/out_div
return output return output
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value):
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total)
def update(self, value):
self.update_absolute(self.current + value)

View File

@ -4,7 +4,10 @@
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from typing import Literal try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -10,7 +10,17 @@ def load_hypernetwork_patch(path, strength):
activate_output = sd.get('activate_output', False) activate_output = sd.get('activate_output', False)
last_layer_dropout = sd.get('last_layer_dropout', False) last_layer_dropout = sd.get('last_layer_dropout', False)
if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False: valid_activation = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
return None return None
@ -28,15 +38,27 @@ def load_hypernetwork_patch(path, strength):
keys = attn_weights.keys() keys = attn_weights.keys()
linears = filter(lambda a: a.endswith(".weight"), keys) linears = filter(lambda a: a.endswith(".weight"), keys)
linears = sorted(list(map(lambda a: a[:-len(".weight")], linears))) linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = [] layers = []
for lin_name in linears: for i in range(len(linears)):
lin_name = linears[i]
last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2))
lin_weight = attn_weights['{}.weight'.format(lin_name)] lin_weight = attn_weights['{}.weight'.format(lin_name)]
lin_bias = attn_weights['{}.bias'.format(lin_name)] lin_bias = attn_weights['{}.bias'.format(lin_name)]
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
layers += [layer] layers.append(layer)
if activation_func != "linear":
if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]())
if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3))
output.append(torch.nn.Sequential(*layers)) output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output) out[dim] = torch.nn.ModuleList(output)
@ -71,7 +93,7 @@ class HypernetworkLoader:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork" FUNCTION = "load_hypernetwork"
CATEGORY = "_for_testing" CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength): def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)

View File

@ -40,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = unique_id input_data_all[x] = unique_id
return input_data_all return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data={}): def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs: if unique_id in outputs:
return [] return
executed = []
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
@ -57,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data) recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
@ -72,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
if "result" in outputs[unique_id]: if "result" in outputs[unique_id]:
outputs[unique_id] = outputs[unique_id]["result"] outputs[unique_id] = outputs[unique_id]["result"]
return executed + [unique_id] executed.add(unique_id)
def recursive_will_execute(prompt, outputs, current_item): def recursive_will_execute(prompt, outputs, current_item):
unique_id = current_item unique_id = current_item
@ -99,40 +97,44 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
is_changed_old = '' is_changed_old = ''
is_changed = '' is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'): if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed'] is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]: if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs) input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None: if input_data_all is not None:
is_changed = class_def.IS_CHANGED(**input_data_all) try:
prompt[unique_id]['is_changed'] = is_changed is_changed = class_def.IS_CHANGED(**input_data_all)
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
else: else:
is_changed = prompt[unique_id]['is_changed'] is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs: if unique_id not in outputs:
return True return True
to_delete = False if not to_delete:
if is_changed != is_changed_old: if is_changed != is_changed_old:
to_delete = True to_delete = True
elif unique_id not in old_prompt: elif unique_id not in old_prompt:
to_delete = True to_delete = True
elif inputs == old_prompt[unique_id]['inputs']: elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
if isinstance(input_data, list): if isinstance(input_data, list):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id in outputs: if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else: else:
to_delete = True to_delete = True
if to_delete: if to_delete:
break break
else: else:
to_delete = True to_delete = True
if to_delete: if to_delete:
d = outputs.pop(unique_id) d = outputs.pop(unique_id)
@ -154,11 +156,20 @@ class PromptExecutor:
self.server.client_id = None self.server.client_id = None
with torch.inference_mode(): with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
for x in prompt: for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys()) current_outputs = set(self.outputs.keys())
executed = [] executed = set()
try: try:
to_execute = [] to_execute = []
for x in prompt: for x in prompt:
@ -181,12 +192,12 @@ class PromptExecutor:
except: except:
valid = False valid = False
if valid: if valid:
executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data) recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
to_delete = [] to_delete = []
for o in self.outputs: for o in self.outputs:
if o not in current_outputs: if (o not in current_outputs) and (o not in executed):
to_delete += [o] to_delete += [o]
if o in self.old_prompt: if o in self.old_prompt:
d = self.old_prompt.pop(o) d = self.old_prompt.pop(o)
@ -194,11 +205,9 @@ class PromptExecutor:
for o in to_delete: for o in to_delete:
d = self.outputs.pop(o) d = self.outputs.pop(o)
del d del d
else: finally:
executed = set(executed)
for x in executed: for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.old_prompt[x] = copy.deepcopy(prompt[x])
finally:
self.server.last_node_id = None self.server.last_node_id = None
if self.server.client_id is not None: if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None }, self.server.client_id) self.server.send_sync("executing", { "node": None }, self.server.client_id)
@ -249,9 +258,15 @@ def validate_inputs(prompt, item):
if "max" in info[1] and val > info[1]["max"]: if "max" in info[1] and val > info[1]["max"]:
return (False, "Value bigger than max. {}, {}".format(class_type, x)) return (False, "Value bigger than max. {}, {}".format(class_type, x))
if isinstance(type_input, list): if hasattr(obj_class, "VALIDATE_INPUTS"):
if val not in type_input: input_data_all = get_input_data(inputs, obj_class, unique_id)
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) ret = obj_class.VALIDATE_INPUTS(**input_data_all)
if ret != True:
return (False, "{}, {}".format(class_type, ret))
else:
if isinstance(type_input, list):
if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
return (True, "") return (True, "")
def validate_prompt(prompt): def validate_prompt(prompt):
@ -273,7 +288,8 @@ def validate_prompt(prompt):
m = validate_inputs(prompt, o) m = validate_inputs(prompt, o)
valid = m[0] valid = m[0]
reason = m[1] reason = m[1]
except: except Exception as e:
print(traceback.format_exc())
valid = False valid = False
reason = "Parsing error" reason = "Parsing error"

View File

@ -13,6 +13,7 @@ a111:
models/ESRGAN models/ESRGAN
models/SwinIR models/SwinIR
embeddings: embeddings embeddings: embeddings
hypernetworks: models/hypernetworks
controlnet: models/ControlNet controlnet: models/ControlNet
#other_ui: #other_ui:

View File

@ -69,6 +69,46 @@ def get_directory_by_type(type_name):
return None return None
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def annotated_filepath(name):
if name.endswith("[output]"):
base_dir = get_output_directory()
name = name[:-9]
elif name.endswith("[input]"):
base_dir = get_input_directory()
name = name[:-8]
elif name.endswith("[temp]"):
base_dir = get_temp_directory()
name = name[:-7]
else:
return name, None
return name, base_dir
def get_annotated_filepath(name, default_dir=None):
name, base_dir = annotated_filepath(name)
if base_dir is None:
if default_dir is not None:
base_dir = default_dir
else:
base_dir = get_input_directory() # fallback path
return os.path.join(base_dir, name)
def exists_annotated_filepath(name):
name, base_dir = annotated_filepath(name)
if base_dir is None:
base_dir = get_input_directory() # fallback path
filepath = os.path.join(base_dir, name)
return os.path.exists(filepath)
def add_model_folder_path(folder_name, full_folder_path): def add_model_folder_path(folder_name, full_folder_path):
global folder_names_and_paths global folder_names_and_paths
if folder_name in folder_names_and_paths: if folder_name in folder_names_and_paths:

12
main.py
View File

@ -5,6 +5,7 @@ import shutil
import threading import threading
from comfy.cli_args import args from comfy.cli_args import args
import comfy.utils
if os.name == "nt": if os.name == "nt":
import logging import logging
@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server): def hijack_progress(server):
from tqdm.auto import tqdm def hook(value, total):
orig_func = getattr(tqdm, "update") server.send_sync("progress", { "value": value, "max": total}, server.client_id)
def wrapped_func(*args, **kwargs): comfy.utils.set_progress_bar_global_hook(hook)
pbar = args[0]
v = orig_func(*args, **kwargs)
server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id)
return v
setattr(tqdm, "update", wrapped_func)
def cleanup_temp(): def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")

264
nodes.py
View File

@ -5,6 +5,7 @@ import sys
import json import json
import hashlib import hashlib
import traceback import traceback
import math
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -16,6 +17,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co
import comfy.diffusers_convert import comfy.diffusers_convert
import comfy.samplers import comfy.samplers
import comfy.sample
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
@ -58,14 +60,44 @@ class ConditioningCombine:
def combine(self, conditioning_1, conditioning_2): def combine(self, conditioning_1, conditioning_2):
return (conditioning_1 + conditioning_2, ) return (conditioning_1 + conditioning_2, )
class ConditioningAverage :
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
"conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "addWeighted"
CATEGORY = "conditioning"
def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
out = []
if len(conditioning_from) > 1:
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0]
for i in range(len(conditioning_to)):
t1 = conditioning_to[i][0]
t0 = cond_from[:,:t1.shape[1]]
if t0.shape[1] < t1.shape[1]:
t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
n = [tw, conditioning_to[i][1].copy()]
out.append(n)
return (out, )
class ConditioningSetArea: class ConditioningSetArea:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ), return {"required": {"conditioning": ("CONDITIONING", ),
"width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}} }}
RETURN_TYPES = ("CONDITIONING",) RETURN_TYPES = ("CONDITIONING",)
@ -79,11 +111,41 @@ class ConditioningSetArea:
n = [t[0], t[1].copy()] n = [t[0], t[1].copy()]
n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
n[1]['strength'] = strength n[1]['strength'] = strength
n[1]['set_area_to_bounds'] = False
n[1]['min_sigma'] = min_sigma n[1]['min_sigma'] = min_sigma
n[1]['max_sigma'] = max_sigma n[1]['max_sigma'] = max_sigma
c.append(n) c.append(n)
return (c, ) return (c, )
class ConditioningSetMask:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"mask": ("MASK", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"set_cond_area": (["default", "mask bounds"],),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, mask, set_cond_area, strength):
c = []
set_area_to_bounds = False
if set_cond_area != "default":
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
for t in conditioning:
n = [t[0], t[1].copy()]
_, h, w = mask.shape
n[1]['mask'] = mask
n[1]['set_area_to_bounds'] = set_area_to_bounds
n[1]['mask_strength'] = strength
c.append(n)
return (c, )
class VAEDecode: class VAEDecode:
def __init__(self, device="cpu"): def __init__(self, device="cpu"):
self.device = device self.device = device
@ -126,16 +188,21 @@ class VAEEncode:
CATEGORY = "latent" CATEGORY = "latent"
def encode(self, vae, pixels): @staticmethod
x = (pixels.shape[1] // 64) * 64 def vae_encode_crop_pixels(pixels):
y = (pixels.shape[2] // 64) * 64 x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 8) * 8
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:] x_offset = (pixels.shape[1] % 8) // 2
y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
return pixels
def encode(self, vae, pixels):
pixels = self.vae_encode_crop_pixels(pixels)
t = vae.encode(pixels[:,:,:,:3]) t = vae.encode(pixels[:,:,:,:3])
return ({"samples":t}, ) return ({"samples":t}, )
class VAEEncodeTiled: class VAEEncodeTiled:
def __init__(self, device="cpu"): def __init__(self, device="cpu"):
self.device = device self.device = device
@ -149,46 +216,51 @@ class VAEEncodeTiled:
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def encode(self, vae, pixels): def encode(self, vae, pixels):
x = (pixels.shape[1] // 64) * 64 pixels = VAEEncode.vae_encode_crop_pixels(pixels)
y = (pixels.shape[2] // 64) * 64
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
t = vae.encode_tiled(pixels[:,:,:,:3]) t = vae.encode_tiled(pixels[:,:,:,:3])
return ({"samples":t}, ) return ({"samples":t}, )
class VAEEncodeForInpaint: class VAEEncodeForInpaint:
def __init__(self, device="cpu"): def __init__(self, device="cpu"):
self.device = device self.device = device
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}} return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "encode" FUNCTION = "encode"
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask): def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // 64) * 64 x = (pixels.shape[1] // 8) * 8
y = (pixels.shape[2] // 64) * 64 y = (pixels.shape[2] // 8) * 8
mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0] mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone() pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:] x_offset = (pixels.shape[1] % 8) // 2
mask = mask[:x,:y] y_offset = (pixels.shape[2] % 8) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
#grow mask by a few pixels to keep things seamless in latent space #grow mask by a few pixels to keep things seamless in latent space
kernel_tensor = torch.ones((1, 1, 6, 6)) if grow_mask_by == 0:
mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1) mask_erosion = mask
m = (1.0 - mask.round()) else:
kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
padding = math.ceil((grow_mask_by - 1) / 2)
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
m = (1.0 - mask.round()).squeeze(1)
for i in range(3): for i in range(3):
pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] -= 0.5
pixels[:,:,:,i] *= m pixels[:,:,:,i] *= m
pixels[:,:,:,i] += 0.5 pixels[:,:,:,i] += 0.5
t = vae.encode(pixels) t = vae.encode(pixels)
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
class CheckpointLoader: class CheckpointLoader:
@classmethod @classmethod
@ -542,8 +614,8 @@ class EmptyLatentImage:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "generate" FUNCTION = "generate"
@ -581,8 +653,8 @@ class LatentUpscale:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"crop": (s.crop_methods,)}} "crop": (s.crop_methods,)}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "upscale" FUNCTION = "upscale"
@ -684,8 +756,8 @@ class LatentCrop:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), return {"required": { "samples": ("LATENT",),
"width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}} }}
@ -710,16 +782,6 @@ class LatentCrop:
new_width = width // 8 new_width = width // 8
to_x = new_width + x to_x = new_width + x
to_y = new_height + y to_y = new_height + y
def enforce_image_dim(d, to_d, max_d):
if to_d > max_d:
leftover = (to_d - max_d) % 8
to_d = max_d
d -= leftover
return (d, to_d)
#make sure size is always multiple of 64
x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
s['samples'] = samples[:,:,y:to_y, x:to_x] s['samples'] = samples[:,:,y:to_y, x:to_x]
return (s,) return (s,)
@ -739,79 +801,27 @@ class SetLatentNoiseMask:
s["noise_mask"] = mask s["noise_mask"] = mask
return (s,) return (s,)
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
noise_mask = None
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
latent_image = latent["samples"]
if disable_noise: if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else: else:
batch_index = 0 skip = latent["batch_index"] if "batch_index" in latent else 0
if "batch_index" in latent: noise = comfy.sample.prepare_noise(latent_image, seed, skip)
batch_index = latent["batch_index"]
generator = torch.manual_seed(seed)
for i in range(batch_index):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise_mask = None
if "noise_mask" in latent: if "noise_mask" in latent:
noise_mask = latent['noise_mask'] noise_mask = latent["noise_mask"]
noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
noise_mask = torch.cat([noise_mask] * noise.shape[0])
noise_mask = noise_mask.to(device)
real_model = None pbar = comfy.utils.ProgressBar(steps)
comfy.model_management.load_model_gpu(model) def callback(step, x0, x):
real_model = model.model pbar.update_absolute(step + 1)
noise = noise.to(device)
latent_image = latent_image.to(device)
positive_copy = []
negative_copy = []
control_nets = []
def get_models(cond):
models = []
for c in cond:
if 'control' in c[1]:
models += [c[1]['control']]
if 'gligen' in c[1]:
models += [c[1]['gligen'][1]]
return models
for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
negative_copy += [[t] + n[1:]]
models = get_models(positive) + get_models(negative)
comfy.model_management.load_controlnet_gpu(models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
else:
#other samplers
pass
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu()
for m in models:
m.cleanup()
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
out = latent.copy() out = latent.copy()
out["samples"] = samples out["samples"] = samples
return (out, ) return (out, )
@ -974,8 +984,7 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK") RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
input_dir = folder_paths.get_input_directory() image_path = folder_paths.get_annotated_filepath(image)
image_path = os.path.join(input_dir, image)
i = Image.open(image_path) i = Image.open(image_path)
image = i.convert("RGB") image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
@ -989,20 +998,27 @@ class LoadImage:
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):
input_dir = folder_paths.get_input_directory() image_path = folder_paths.get_annotated_filepath(image)
image_path = os.path.join(input_dir, image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
m.update(f.read()) m.update(f.read())
return m.digest().hex() return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageMask: class LoadImageMask:
_color_channels = ["alpha", "red", "green", "blue"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
return {"required": return {"required":
{"image": (sorted(os.listdir(input_dir)), ), {"image": (sorted(os.listdir(input_dir)), ),
"channel": (["alpha", "red", "green", "blue"], ),} "channel": (s._color_channels, ),}
} }
CATEGORY = "mask" CATEGORY = "mask"
@ -1010,8 +1026,7 @@ class LoadImageMask:
RETURN_TYPES = ("MASK",) RETURN_TYPES = ("MASK",)
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image, channel): def load_image(self, image, channel):
input_dir = folder_paths.get_input_directory() image_path = folder_paths.get_annotated_filepath(image)
image_path = os.path.join(input_dir, image)
i = Image.open(image_path) i = Image.open(image_path)
if i.getbands() != ("R", "G", "B", "A"): if i.getbands() != ("R", "G", "B", "A"):
i = i.convert("RGBA") i = i.convert("RGBA")
@ -1028,13 +1043,22 @@ class LoadImageMask:
@classmethod @classmethod
def IS_CHANGED(s, image, channel): def IS_CHANGED(s, image, channel):
input_dir = folder_paths.get_input_directory() image_path = folder_paths.get_annotated_filepath(image)
image_path = os.path.join(input_dir, image)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
m.update(f.read()) m.update(f.read())
return m.digest().hex() return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image, channel):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True
class ImageScale: class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area"] upscale_methods = ["nearest-exact", "bilinear", "area"]
crop_methods = ["disabled", "center"] crop_methods = ["disabled", "center"]
@ -1079,10 +1103,10 @@ class ImagePadForOutpaint:
return { return {
"required": { "required": {
"image": ("IMAGE",), "image": ("IMAGE",),
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
} }
} }
@ -1154,8 +1178,10 @@ NODE_CLASS_MAPPINGS = {
"ImageScale": ImageScale, "ImageScale": ImageScale,
"ImageInvert": ImageInvert, "ImageInvert": ImageInvert,
"ImagePadForOutpaint": ImagePadForOutpaint, "ImagePadForOutpaint": ImagePadForOutpaint,
"ConditioningAverage ": ConditioningAverage ,
"ConditioningCombine": ConditioningCombine, "ConditioningCombine": ConditioningCombine,
"ConditioningSetArea": ConditioningSetArea, "ConditioningSetArea": ConditioningSetArea,
"ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced, "KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask, "SetLatentNoiseMask": SetLatentNoiseMask,
"LatentComposite": LatentComposite, "LatentComposite": LatentComposite,
@ -1204,7 +1230,9 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CLIPTextEncode": "CLIP Text Encode (Prompt)", "CLIPTextEncode": "CLIP Text Encode (Prompt)",
"CLIPSetLastLayer": "CLIP Set Last Layer", "CLIPSetLastLayer": "CLIP Set Last Layer",
"ConditioningCombine": "Conditioning (Combine)", "ConditioningCombine": "Conditioning (Combine)",
"ConditioningAverage ": "Conditioning (Average)",
"ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet", "ControlNetApply": "Apply ControlNet",
# Latent # Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",

View File

@ -47,7 +47,7 @@
" !git pull\n", " !git pull\n",
"\n", "\n",
"!echo -= Install dependencies =-\n", "!echo -= Install dependencies =-\n",
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118" "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
] ]
}, },
{ {

View File

@ -112,13 +112,20 @@ class PromptServer():
@routes.post("/upload/image") @routes.post("/upload/image")
async def upload_image(request): async def upload_image(request):
upload_dir = folder_paths.get_input_directory() post = await request.post()
image = post.get("image")
if post.get("type") is None:
upload_dir = folder_paths.get_input_directory()
elif post.get("type") == "input":
upload_dir = folder_paths.get_input_directory()
elif post.get("type") == "temp":
upload_dir = folder_paths.get_temp_directory()
elif post.get("type") == "output":
upload_dir = folder_paths.get_output_directory()
if not os.path.exists(upload_dir): if not os.path.exists(upload_dir):
os.makedirs(upload_dir) os.makedirs(upload_dir)
post = await request.post()
image = post.get("image")
if image and image.file: if image and image.file:
filename = image.filename filename = image.filename

View File

@ -232,10 +232,27 @@ app.registerExtension({
"name": "My Color Palette", "name": "My Color Palette",
"colors": { "colors": {
"node_slot": { "node_slot": {
},
"litegraph_base": {
},
"comfy_base": {
} }
} }
}; };
// Copy over missing keys from default color palette
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
for (const key in defaultColorPalette.colors.litegraph_base) {
if (!colorPalette.colors.litegraph_base[key]) {
colorPalette.colors.litegraph_base[key] = "";
}
}
for (const key in defaultColorPalette.colors.comfy_base) {
if (!colorPalette.colors.comfy_base[key]) {
colorPalette.colors.comfy_base[key] = "";
}
}
return completeColorPalette(colorPalette); return completeColorPalette(colorPalette);
}; };

View File

@ -6,6 +6,7 @@ app.registerExtension({
name: "Comfy.SlotDefaults", name: "Comfy.SlotDefaults",
suggestionsNumber: null, suggestionsNumber: null,
init() { init() {
LiteGraph.search_filter_enabled = true;
LiteGraph.middle_click_slot_add_default_node = true; LiteGraph.middle_click_slot_add_default_node = true;
this.suggestionsNumber = app.ui.settings.addSetting({ this.suggestionsNumber = app.ui.settings.addSetting({
id: "Comfy.NodeSuggestions.number", id: "Comfy.NodeSuggestions.number",
@ -43,6 +44,14 @@ app.registerExtension({
} }
if (this.slot_types_default_out[type].includes(nodeId)) continue; if (this.slot_types_default_out[type].includes(nodeId)) continue;
this.slot_types_default_out[type].push(nodeId); this.slot_types_default_out[type].push(nodeId);
// Input types have to be stored as lower case
// Store each node that can handle this input type
const lowerType = type.toLocaleLowerCase();
if (!(lowerType in LiteGraph.registered_slot_in_types)) {
LiteGraph.registered_slot_in_types[lowerType] = { nodes: [] };
}
LiteGraph.registered_slot_in_types[lowerType].nodes.push(nodeType.comfyClass);
} }
var outputs = nodeData["output"]; var outputs = nodeData["output"];
@ -53,6 +62,16 @@ app.registerExtension({
} }
this.slot_types_default_in[type].push(nodeId); this.slot_types_default_in[type].push(nodeId);
// Store each node that can handle this output type
if (!(type in LiteGraph.registered_slot_out_types)) {
LiteGraph.registered_slot_out_types[type] = { nodes: [] };
}
LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass);
if(!LiteGraph.slot_types_out.includes(type)) {
LiteGraph.slot_types_out.push(type);
}
} }
var maxNum = this.suggestionsNumber.value; var maxNum = this.suggestionsNumber.value;
this.setDefaults(maxNum); this.setDefaults(maxNum);

View File

@ -3628,6 +3628,18 @@
return size; return size;
}; };
LGraphNode.prototype.inResizeCorner = function(canvasX, canvasY) {
var rows = this.outputs ? this.outputs.length : 1;
var outputs_offset = (this.constructor.slot_start_y || 0) + rows * LiteGraph.NODE_SLOT_HEIGHT;
return isInsideRectangle(canvasX,
canvasY,
this.pos[0] + this.size[0] - 15,
this.pos[1] + Math.max(this.size[1] - 15, outputs_offset),
20,
20
);
}
/** /**
* returns all the info available about a property of this node. * returns all the info available about a property of this node.
* *
@ -5877,14 +5889,7 @@ LGraphNode.prototype.executeAction = function(action)
if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) {
//Search for corner for resize //Search for corner for resize
if ( !skip_action && if ( !skip_action &&
node.resizable !== false && node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY)
isInsideRectangle( e.canvasX,
e.canvasY,
node.pos[0] + node.size[0] - 5,
node.pos[1] + node.size[1] - 5,
10,
10
)
) { ) {
this.graph.beforeChange(); this.graph.beforeChange();
this.resizing_node = node; this.resizing_node = node;
@ -6424,16 +6429,7 @@ LGraphNode.prototype.executeAction = function(action)
//Search for corner //Search for corner
if (this.canvas) { if (this.canvas) {
if ( if (node.inResizeCorner(e.canvasX, e.canvasY)) {
isInsideRectangle(
e.canvasX,
e.canvasY,
node.pos[0] + node.size[0] - 5,
node.pos[1] + node.size[1] - 5,
5,
5
)
) {
this.canvas.style.cursor = "se-resize"; this.canvas.style.cursor = "se-resize";
} else { } else {
this.canvas.style.cursor = "crosshair"; this.canvas.style.cursor = "crosshair";
@ -9953,11 +9949,11 @@ LGraphNode.prototype.executeAction = function(action)
} }
break; break;
case "slider": case "slider":
var range = w.options.max - w.options.min; var old_value = w.value;
var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1); var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1);
if(w.options.read_only) break; if(w.options.read_only) break;
w.value = w.options.min + (w.options.max - w.options.min) * nvalue; w.value = w.options.min + (w.options.max - w.options.min) * nvalue;
if (w.callback) { if (old_value != w.value) {
setTimeout(function() { setTimeout(function() {
inner_value_change(w, w.value); inner_value_change(w, w.value);
}, 20); }, 20);
@ -10044,7 +10040,7 @@ LGraphNode.prototype.executeAction = function(action)
if (event.click_time < 200 && delta == 0) { if (event.click_time < 200 && delta == 0) {
this.prompt("Value",w.value,function(v) { this.prompt("Value",w.value,function(v) {
// check if v is a valid equation or a number // check if v is a valid equation or a number
if (/^[0-9+\-*/()\s]+$/.test(v)) { if (/^[0-9+\-*/()\s]+|\d+\.\d+$/.test(v)) {
try {//solve the equation if possible try {//solve the equation if possible
v = eval(v); v = eval(v);
} catch (e) { } } catch (e) { }

View File

@ -20,6 +20,12 @@ export class ComfyApp {
*/ */
#processingQueue = false; #processingQueue = false;
/**
* Content Clipboard
* @type {serialized node object}
*/
static clipspace = null;
constructor() { constructor() {
this.ui = new ComfyUI(this); this.ui = new ComfyUI(this);
@ -130,6 +136,83 @@ export class ComfyApp {
); );
} }
} }
options.push(
{
content: "Copy (Clipspace)",
callback: (obj) => {
var widgets = null;
if(this.widgets) {
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
}
let img = new Image();
var imgs = undefined;
if(this.imgs != undefined) {
img.src = this.imgs[0].src;
imgs = [img];
}
ComfyApp.clipspace = {
'widgets': widgets,
'imgs': imgs,
'original_imgs': imgs,
'images': this.images
};
}
});
if(ComfyApp.clipspace != null) {
options.push(
{
content: "Paste (Clipspace)",
callback: () => {
if(ComfyApp.clipspace != null) {
if(ComfyApp.clipspace.widgets != null && this.widgets != null) {
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
if (prop) {
prop.callback(value);
}
});
}
// image paste
if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) {
var filename = "";
if(this.images && ComfyApp.clipspace.images) {
this.images = ComfyApp.clipspace.images;
}
if(ComfyApp.clipspace.images != undefined) {
const clip_image = ComfyApp.clipspace.images[0];
if(clip_image.subfolder != '')
filename = `${clip_image.subfolder}/`;
filename += `${clip_image.filename} [${clip_image.type}]`;
}
else if(ComfyApp.clipspace.widgets != undefined) {
const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
if(index_in_clip >= 0) {
filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`;
}
}
const index = this.widgets.findIndex(obj => obj.name === 'image');
if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) {
this.imgs = ComfyApp.clipspace.imgs;
this.widgets[index].value = filename;
if(this.widgets_values != undefined) {
this.widgets_values[index] = filename;
}
}
}
this.trigger('changed');
}
}
}
);
}
}; };
} }
@ -888,8 +971,10 @@ export class ComfyApp {
loadGraphData(graphData) { loadGraphData(graphData) {
this.clean(); this.clean();
let reset_invalid_values = false;
if (!graphData) { if (!graphData) {
graphData = structuredClone(defaultGraph); graphData = structuredClone(defaultGraph);
reset_invalid_values = true;
} }
const missingNodeTypes = []; const missingNodeTypes = [];
@ -975,6 +1060,13 @@ export class ComfyApp {
} }
} }
} }
if (reset_invalid_values) {
if (widget.type == "combo") {
if (!widget.options.values.includes(widget.value) && widget.options.values.length > 0) {
widget.value = widget.options.values[0];
}
}
}
} }
} }

View File

@ -136,9 +136,11 @@ function addMultilineWidget(node, name, opts, app) {
left: `${t.a * margin + t.e}px`, left: `${t.a * margin + t.e}px`,
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
background: (!node.color)?'':node.color,
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
position: "absolute", position: "absolute",
zIndex: 1, color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
fontSize: `${t.d * 10.0}px`, fontSize: `${t.d * 10.0}px`,
}); });
this.inputEl.hidden = !visible; this.inputEl.hidden = !visible;
@ -270,6 +272,9 @@ export const ComfyWidgets = {
app.graph.setDirtyCanvas(true); app.graph.setDirtyCanvas(true);
}; };
img.src = `/view?filename=${name}&type=input`; img.src = `/view?filename=${name}&type=input`;
if ((node.size[1] - node.imageOffset) < 100) {
node.size[1] = 250 + node.imageOffset;
}
} }
// Add our own callback to the combo widget to render an image when it changes // Add our own callback to the combo widget to render an image when it changes

View File

@ -120,7 +120,7 @@ body {
.comfy-menu > button, .comfy-menu > button,
.comfy-menu-btns button, .comfy-menu-btns button,
.comfy-menu .comfy-list button, .comfy-menu .comfy-list button,
.comfy-modal button{ .comfy-modal button {
color: var(--input-text); color: var(--input-text);
background-color: var(--comfy-input-bg); background-color: var(--comfy-input-bg);
border-radius: 8px; border-radius: 8px;
@ -129,6 +129,15 @@ body {
margin-top: 2px; margin-top: 2px;
} }
.comfy-menu > button:hover,
.comfy-menu-btns button:hover,
.comfy-menu .comfy-list button:hover,
.comfy-modal button:hover,
.comfy-settings-btn:hover {
filter: brightness(1.2);
cursor: pointer;
}
.comfy-menu span.drag-handle { .comfy-menu span.drag-handle {
width: 10px; width: 10px;
height: 20px; height: 20px;
@ -248,8 +257,11 @@ button.comfy-queue-btn {
} }
} }
/* Input popup */
.graphdialog { .graphdialog {
min-height: 1em; min-height: 1em;
background-color: var(--comfy-menu-bg);
} }
.graphdialog .name { .graphdialog .name {
@ -273,15 +285,66 @@ button.comfy-queue-btn {
border-radius: 12px 0 0 12px; border-radius: 12px 0 0 12px;
} }
/* Context menu */
.litegraph .litemenu-entry.has_submenu { .litegraph .litemenu-entry.has_submenu {
position: relative; position: relative;
padding-right: 20px; padding-right: 20px;
} }
.litemenu-entry.has_submenu::after { .litemenu-entry.has_submenu::after {
content: ">"; content: ">";
position: absolute; position: absolute;
top: 0; top: 0;
right: 2px; right: 2px;
} }
.litegraph.litecontextmenu,
.litegraph.litecontextmenu.dark {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
filter: brightness(95%);
}
.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) {
background-color: var(--comfy-menu-bg) !important;
filter: brightness(155%);
color: var(--input-text);
}
.litegraph.litecontextmenu .litemenu-entry.submenu,
.litegraph.litecontextmenu.dark .litemenu-entry.submenu {
background-color: var(--comfy-menu-bg) !important;
color: var(--input-text);
}
.litegraph.litecontextmenu input {
background-color: var(--comfy-input-bg) !important;
color: var(--input-text) !important;
}
/* Search box */
.litegraph.litesearchbox {
z-index: 9999 !important;
background-color: var(--comfy-menu-bg) !important;
overflow: hidden;
}
.litegraph.litesearchbox input,
.litegraph.litesearchbox select {
background-color: var(--comfy-input-bg) !important;
color: var(--input-text);
}
.litegraph.lite-search-item {
color: var(--input-text);
background-color: var(--comfy-input-bg);
filter: brightness(80%);
padding-left: 0.2em;
}
.litegraph.lite-search-item.generic_type {
color: var(--input-text);
filter: brightness(50%);
}