Merge branch 'master' into image-cache

Jairo Correa 2023-12-02 05:16:21 -03:00 committed by GitHub
commit c92f3dca73
57 changed files with 5588 additions and 924 deletions

.vscode/settings.json vendored Normal file

@ -0,0 +1,9 @@
{
"path-intellisense.mappings": {
"../": "${workspaceFolder}/web/extensions/core"
},
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}


@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
-- Fully supports SD1.x, SD2.x and SDXL
+- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@ -30,6 +30,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
+- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
+- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
@ -43,6 +45,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
+| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
@ -98,6 +101,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
This is the command to install the nightly with ROCm 5.7 that might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
### NVIDIA
@ -190,7 +194,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
-You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers this option does not do anything.
+You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
```--dont-upcast-attention```


@ -54,6 +54,7 @@ class ControlNet(nn.Module):
transformer_depth_output=None,
device=None,
operations=comfy.ops,
+**kwargs,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"


@ -62,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
+fpte_group = parser.add_mutually_exclusive_group()
+fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")


@ -33,7 +33,7 @@ class ControlBase:
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
-self.timestep_percent_range = (1.0, 0.0)
+self.timestep_percent_range = (0.0, 1.0)
self.timestep_range = None
if device is None:
@ -42,7 +42,7 @@ class ControlBase:
self.previous_controlnet = None
self.global_average_pooling = False
-def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
+def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
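The (1.0, 0.0) to (0.0, 1.0) flip matters because the tuple is read as (start_percent, end_percent) when it is later turned into a concrete timestep window, with 0.0 meaning the start of sampling and 1.0 the end. A rough, purely illustrative sketch of that convention (the real code converts percents through the model's sigma schedule, which is not shown here):

```python
# Hypothetical illustration only: percents run from 0.0 (first sampling step)
# to 1.0 (last step), so (0.0, 1.0) means "apply the controlnet everywhere".
def percent_window_active(start_percent: float, end_percent: float,
                          step: int, total_steps: int) -> bool:
    progress = step / max(total_steps - 1, 1)   # 0.0 at the first step, 1.0 at the last
    return start_percent <= progress <= end_percent

assert percent_window_active(0.0, 1.0, step=0, total_steps=20)       # always on
assert not percent_window_active(0.5, 1.0, step=2, total_steps=20)   # only the second half
```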


@ -858,7 +858,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
return (input - model(input, sigma_in, **kwargs)) / sigma
-def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
+def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
timesteps = sigmas.clone()
if sigmas[-1] == 0:
timesteps = sigmas[:]


@ -750,3 +750,61 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_end = sigmas[-1]
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == s_end:
# Euler method
x = x + d * dt
elif sigmas[i + 2] == s_end:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0]
w2 = sigmas[i+1]/w
w1 = 1 - w2
d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt
else:
# Heun++
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
dt_2 = sigmas[i + 2] - sigmas[i + 1]
x_3 = x_2 + d_2 * dt_2
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
w = 3 * sigmas[0]
w2 = sigmas[i + 1] / w
w3 = sigmas[i + 2] / w
w1 = 1 - w2 - w3
d_prime = w1 * d + w2 * d_2 + w3 * d_3
x = x + d_prime * dt
return x
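A quick smoke test of the new heunpp2 sampler, assuming it is exported from comfy.k_diffusion.sampling like the other samplers in this file; the denoiser below is a toy stand-in, not a real diffusion model:

```python
import torch
from comfy.k_diffusion.sampling import sample_heunpp2  # assumed import path

class ToyDenoiser(torch.nn.Module):
    # Stand-in denoiser: just decays the input, enough to exercise the sampler loop.
    def forward(self, x, sigma, **kwargs):
        return x * 0.5

model = ToyDenoiser()
x = torch.randn(1, 4, 8, 8)                        # latent-shaped noise
sigmas = torch.tensor([14.6, 7.0, 3.0, 1.0, 0.0])  # descending schedule ending at 0
out = sample_heunpp2(model, x, sigmas)
print(out.shape)  # torch.Size([1, 4, 8, 8])
```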


@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
+from functools import partial
-from .diffusionmodules.util import checkpoint
+from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
@ -276,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
)
return r1
+BROKEN_XFORMERS = False
+try:
+x_vers = xformers.__version__
+#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
+BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
+except:
+pass
def attention_xformers(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
+if BROKEN_XFORMERS:
+if b * heads > 65535:
+return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
lambda t: t.unsqueeze(3)
@ -370,53 +383,72 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, dtype=None, device=None, operations=comfy.ops): disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
super().__init__() super().__init__()
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
context_dim_attn2 = None
if not switch_temporal_ca_to_sa:
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.n_heads = n_heads self.n_heads = n_heads
self.d_head = d_head self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
def forward(self, x, context=None, transformer_options={}): def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}): def _forward(self, x, context=None, transformer_options={}):
extra_options = {} extra_options = {}
block = None block = transformer_options.get("block", None)
block_index = 0 block_index = transformer_options.get("block_index", 0)
if "current_index" in transformer_options:
extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
block_index = transformer_options["block_index"]
extra_options["block_index"] = block_index
if "original_shape" in transformer_options:
extra_options["original_shape"] = transformer_options["original_shape"]
if "block" in transformer_options:
block = transformer_options["block"]
extra_options["block"] = block
if "cond_or_uncond" in transformer_options:
extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {} transformer_patches = {}
transformer_patches_replace = {}
for k in transformer_options:
if k == "patches":
transformer_patches = transformer_options[k]
elif k == "patches_replace":
transformer_patches_replace = transformer_options[k]
else:
extra_options[k] = transformer_options[k]
extra_options["n_heads"] = self.n_heads extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head extra_options["dim_head"] = self.d_head
if "patches_replace" in transformer_options: if self.ff_in:
transformer_patches_replace = transformer_options["patches_replace"] x_skip = x
else: x = self.ff_in(self.norm_in(x))
transformer_patches_replace = {} if self.is_res:
x += x_skip
n = self.norm1(x) n = self.norm1(x)
if self.disable_self_attn: if self.disable_self_attn:
@ -465,8 +497,11 @@ class BasicTransformerBlock(nn.Module):
for p in patch: for p in patch:
x = p(x, extra_options) x = p(x, extra_options)
if self.attn2 is not None:
n = self.norm2(x) n = self.norm2(x)
if self.switch_temporal_ca_to_sa:
context_attn2 = n
else:
context_attn2 = context context_attn2 = context
value_attn2 = None value_attn2 = None
if "attn2_patch" in transformer_patches: if "attn2_patch" in transformer_patches:
@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
n = p(n, extra_options) n = p(n, extra_options)
x += n x += n
x = self.ff(self.norm3(x)) + x if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
return x return x
@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
# timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(
x,
context=spatial_context,
transformer_options=transformer_options,
)
x_mix = x
x_mix = x_mix + emb
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
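The key move in the temporal path above is re-grouping the batch axis so the time-mixing block attends across frames rather than across spatial tokens. A small shape-only sketch of those rearranges (tensor sizes are made up):

```python
import torch
from einops import rearrange

b, t, s, c = 2, 14, 64, 320            # batches, video frames, spatial tokens, channels
x = torch.randn(b * t, s, c)           # spatial blocks see (b*t) "images" of s tokens each

# time mixing: fold the spatial tokens into the batch and expose the frame axis
x_mix = rearrange(x, "(b t) s c -> (b s) t c", t=t)
print(x_mix.shape)                     # torch.Size([128, 14, 320]) == (b*s, t, c)

# back to the spatial layout after the temporal transformer block
x_back = rearrange(x_mix, "(b s) t c -> (b t) s c", s=s, t=t)
print(torch.equal(x, x_back))          # True: the round trip is lossless
```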


@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
+from einops import rearrange
+from functools import partial
from .util import (
checkpoint,
@ -12,8 +14,9 @@ from .util import (
zero_module,
normalization,
timestep_embedding,
+AlphaBlender,
)
-from ..attention import SpatialTransformer
+from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
@ -28,6 +31,26 @@ class TimestepBlock(nn.Module):
Apply the module to `x` given `emb` timestep embeddings. Apply the module to `x` given `emb` timestep embeddings.
""" """
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialVideoTransformer):
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class TimestepEmbedSequential(nn.Sequential, TimestepBlock): class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
""" """
@ -35,31 +58,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input. support it as an extra input.
""" """
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): def forward(self, *args, **kwargs):
for layer in self: return forward_timestep_embed(self, *args, **kwargs)
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class Upsample(nn.Module): class Upsample(nn.Module):
""" """
@ -154,6 +154,9 @@ class ResBlock(TimestepBlock):
use_checkpoint=False, use_checkpoint=False,
up=False, up=False,
down=False, down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None, dtype=None,
device=None, device=None,
operations=comfy.ops operations=comfy.ops
@ -166,11 +169,17 @@ class ResBlock(TimestepBlock):
self.use_conv = use_conv self.use_conv = use_conv
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims
if isinstance(kernel_size, list):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
) )
self.updown = up or down self.updown = up or down
@ -184,6 +193,11 @@ class ResBlock(TimestepBlock):
else: else:
self.h_upd = self.x_upd = nn.Identity() self.h_upd = self.x_upd = nn.Identity()
self.skip_t_emb = skip_t_emb
if self.skip_t_emb:
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential( self.emb_layers = nn.Sequential(
nn.SiLU(), nn.SiLU(),
operations.Linear( operations.Linear(
@ -196,7 +210,7 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
), ),
) )
@ -204,7 +218,7 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = operations.conv_nd( self.skip_connection = operations.conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
) )
else: else:
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
@ -230,19 +244,110 @@ class ResBlock(TimestepBlock):
h = in_conv(h) h = in_conv(h)
else: else:
h = self.in_layers(x) h = self.in_layers(x)
emb_out = None
if not self.skip_t_emb:
emb_out = self.emb_layers(emb).type(h.dtype) emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape): while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None] emb_out = emb_out[..., None]
if self.use_scale_shift_norm: if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:] out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
h = out_norm(h)
if emb_out is not None:
scale, shift = th.chunk(emb_out, 2, dim=1) scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift h *= (1 + scale)
h += shift
h = out_rest(h) h = out_rest(h)
else: else:
if emb_out is not None:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out h = h + emb_out
h = self.out_layers(h) h = self.out_layers(h)
return self.skip_connection(x) + h return self.skip_connection(x) + h
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
dtype=dtype,
device=device,
operations=operations
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
dtype=dtype,
device=device,
operations=operations
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class Timestep(nn.Module): class Timestep(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
@ -255,7 +360,10 @@ def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
+try:
h += ctrl
+except:
+print("warning control could not be applied", h.shape, ctrl.shape)
return h
class UNetModel(nn.Module):
@ -316,6 +424,16 @@ class UNetModel(nn.Module):
adm_in_channels=None, adm_in_channels=None,
transformer_depth_middle=None, transformer_depth_middle=None,
transformer_depth_output=None, transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None, device=None,
operations=comfy.ops, operations=comfy.ops,
): ):
@ -370,8 +488,12 @@ class UNetModel(nn.Module):
self.num_heads = num_heads self.num_heads = num_heads
self.num_head_channels = num_head_channels self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample self.num_heads_upsample = num_heads_upsample
self.use_temporal_resblocks = use_temporal_resblock
self.predict_codebook_ids = n_embed is not None self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4 time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
@ -408,13 +530,104 @@ class UNetModel(nn.Module):
input_block_chans = [model_channels] input_block_chans = [model_channels]
ch = model_channels ch = model_channels
ds = 1 ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]): def get_attention_layer(
layers = [ ch,
ResBlock( num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
):
if use_temporal_attention:
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch, ch,
time_embed_dim, time_embed_dim,
dropout, dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=comfy.ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
else:
return ResBlock(
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
use_checkpoint=use_checkpoint,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=mult * model_channels, out_channels=mult * model_channels,
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
@ -441,11 +654,9 @@ class UNetModel(nn.Module):
disabled_sa = False disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(SpatialTransformer( layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch self._feature_size += ch
@ -454,10 +665,13 @@ class UNetModel(nn.Module):
out_ch = ch out_ch = ch
self.input_blocks.append( self.input_blocks.append(
TimestepEmbedSequential( TimestepEmbedSequential(
ResBlock( get_resblock(
ch, merge_factor=merge_factor,
time_embed_dim, merge_strategy=merge_strategy,
dropout, video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch, out_channels=out_ch,
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
@ -487,10 +701,14 @@ class UNetModel(nn.Module):
#num_heads = 1 #num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
mid_block = [ mid_block = [
ResBlock( get_resblock(
ch, merge_factor=merge_factor,
time_embed_dim, merge_strategy=merge_strategy,
dropout, video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
@ -499,15 +717,18 @@ class UNetModel(nn.Module):
operations=operations operations=operations
)] )]
if transformer_depth_middle >= 0: if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
), ),
ResBlock( get_resblock(
ch, merge_factor=merge_factor,
time_embed_dim, merge_strategy=merge_strategy,
dropout, video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
@ -523,10 +744,13 @@ class UNetModel(nn.Module):
for i in range(self.num_res_blocks[level] + 1): for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop() ich = input_block_chans.pop()
layers = [ layers = [
ResBlock( get_resblock(
ch + ich, merge_factor=merge_factor,
time_embed_dim, merge_strategy=merge_strategy,
dropout, video_kernel_size=video_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=model_channels * mult, out_channels=model_channels * mult,
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
@ -554,19 +778,21 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]: if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append( layers.append(
SpatialTransformer( get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
) )
) )
if level and i == self.num_res_blocks[level]: if level and i == self.num_res_blocks[level]:
out_ch = ch out_ch = ch
layers.append( layers.append(
ResBlock( get_resblock(
ch, merge_factor=merge_factor,
time_embed_dim, merge_strategy=merge_strategy,
dropout, video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch, out_channels=out_ch,
dims=dims, dims=dims,
use_checkpoint=use_checkpoint, use_checkpoint=use_checkpoint,
@ -605,9 +831,13 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs. :return: an [N x C x ...] Tensor of outputs.
""" """
transformer_options["original_shape"] = list(x.shape) transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0 transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {}) transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"
@ -622,14 +852,24 @@ class UNetModel(nn.Module):
h = x.type(self.dtype) h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks): for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id) transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input') h = apply_control(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h) hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0) transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle') h = apply_control(h, control, 'middle')
for id, module in enumerate(self.output_blocks): for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id) transformer_options["block"] = ("output", id)
hsp = hs.pop() hsp = hs.pop()
@ -646,7 +886,7 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape output_shape = hs[-1].shape
else: else:
output_shape = None output_shape = None
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype) h = h.type(x.dtype)
if self.predict_codebook_ids: if self.predict_codebook_ids:
return self.id_predictor(h) return self.id_predictor(h)


@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
-from einops import repeat
+from einops import repeat, rearrange
from comfy.ldm.util import instantiate_from_config
import comfy.ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
else:
raise NotImplementedError()
return alpha
def forward(
self,
x_spatial,
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
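A short sketch of how the new AlphaBlender behaves with the "fixed" merge strategy, assuming it is importable from the util module modified above:

```python
import torch
from comfy.ldm.modules.diffusionmodules.util import AlphaBlender  # assumed import path

blender = AlphaBlender(alpha=0.25, merge_strategy="fixed")

x_spatial = torch.ones(4, 64, 320)      # per-frame (spatial) features
x_temporal = torch.zeros(4, 64, 320)    # temporally mixed features

out = blender(x_spatial, x_temporal)    # 0.25 * spatial + 0.75 * temporal
print(out.mean().item())                # 0.25
```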


@ -0,0 +1,244 @@
import functools
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
import comfy.ops
from .diffusionmodules.model import (
AttnBlock,
Decoder,
ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
b, c, h, w = x.shape
if timesteps is None:
timesteps = b
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps=None, skip_video=False):
if timesteps is None:
timesteps = input.shape[0]
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class AttnVideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = BasicTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
comfy.ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
comfy.ops.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps=None, skip_time_block=False):
if skip_time_block:
return super().forward(x)
if timesteps is None:
timesteps = x.shape[0]
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
if self.time_mode != "attn-only":
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
if self.time_mode not in ["conv-only", "only-last-conv"]:
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
if self.time_mode not in ["attn-only", "only-last-conv"]:
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)
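The partialclass helper above pre-binds constructor arguments into a subclass, which is how VideoDecoder swaps in its conv, attention and resnet ops. A self-contained illustration (the FakeConv class is made up):

```python
import functools

def partialclass(cls, *args, **kwargs):
    # Same pattern as above: a subclass whose __init__ has some arguments pre-bound.
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
    return NewCls

class FakeConv:
    def __init__(self, channels, kernel_size=3):
        self.channels = channels
        self.kernel_size = kernel_size

FakeConv5 = partialclass(FakeConv, kernel_size=5)
layer = FakeConv5(16)                       # kernel_size is already fixed to 5
print(layer.channels, layer.kernel_size)    # 16 5
```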


@ -10,17 +10,22 @@ from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
+V_PREDICTION_EDM = 3
-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type):
+s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
+elif model_type == ModelType.V_PREDICTION_EDM:
+c = V_PREDICTION
+s = ModelSamplingContinuousEDM
-s = ModelSamplingDiscrete
class ModelSampling(s, c):
pass
@ -121,6 +126,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
+to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
@ -157,6 +163,17 @@ class BaseModel(torch.nn.Module):
def set_inpaint(self):
self.inpaint_model = True
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
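For a sense of scale, plugging a single 1024x1024 image (a 128x128, 4-channel latent) in fp16 into the two formulas above gives roughly the following; the numbers only illustrate the estimate, they are not measurements:

```python
# Shapes are (batch, channels, height, width) of the latent fed to the UNet.
input_shape = (1, 4, 128, 128)           # one 1024x1024 image -> 128x128 latent
area = input_shape[0] * input_shape[2] * input_shape[3]   # 16384
dtype_size = 2                           # fp16 bytes per element

xformers_estimate = (area * dtype_size / 50) * (1024 * 1024)        # ~0.64 GiB
fallback_estimate = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)   # ~11.7 GiB

print(round(xformers_estimate / 1024**3, 2), round(fallback_estimate / 1024**3, 2))
```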
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []
weights = []
@ -251,3 +268,48 @@ class SDXL(BaseModel):
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
fps_id = kwargs.get("fps", 6) - 1
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.Tensor([fps_id])))
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out


@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
-return last_transformer_depth, context_dim, use_linear_in_transformer
+time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
+return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
context_dim = None context_dim = None
use_linear_in_transformer = False use_linear_in_transformer = False
video_model = False
current_res = 1 current_res = 1
count = 0 count = 0
@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
if context_dim is None: if context_dim is None:
context_dim = out[1] context_dim = out[1]
use_linear_in_transformer = out[2] use_linear_in_transformer = out[2]
video_model = out[3]
else: else:
transformer_depth.append(0) transformer_depth.append(0)
@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config["transformer_depth_middle"] = transformer_depth_middle unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer unet_config['use_linear_in_transformer'] = use_linear_in_transformer
unet_config["context_dim"] = context_dim unet_config["context_dim"] = context_dim
if video_model:
unet_config["extra_ff_mix_layer"] = True
unet_config["use_spatial_context"] = True
unet_config["merge_strategy"] = "learned_with_images"
unet_config["merge_factor"] = 0.0
unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True
else:
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False
return unet_config
def model_config_from_unet_config(unet_config):
@@ -186,17 +202,24 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype):
match = {}
transformer_depth = []
attn_res = 1
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
attn_res *= 2
if attn_blocks == 0:
transformer_depth.append(0)
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
@@ -208,50 +231,65 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
for unet_config in supported_models:
matches = True
@@ -133,6 +133,10 @@ else:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
try:
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
except:
pass
try:
XFORMERS_VERSION = xformers.version.__version__
print("xformers version:", XFORMERS_VERSION)
@@ -478,6 +482,21 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn
elif args.fp8_e5m2_text_enc:
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.fp32_text_enc:
return torch.float32
if should_use_fp16(device, prioritize_performance=False):
return torch.float16
else:
return torch.float32
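The precedence in text_encoder_dtype is: explicit fp8/fp16/fp32 flags first, then the fp16 heuristic. A standalone sketch of the same ordering (the flag names mirror the args above, should_use_fp16 is stubbed out, and the float8 dtypes need a recent PyTorch build):

```python
import torch

def pick_text_encoder_dtype(fp8_e4m3fn=False, fp8_e5m2=False, fp16=False, fp32=False,
                            fp16_is_safe=True):
    # Same precedence as text_encoder_dtype above; fp16_is_safe stands in for
    # should_use_fp16(device, prioritize_performance=False).
    if fp8_e4m3fn:
        return torch.float8_e4m3fn
    if fp8_e5m2:
        return torch.float8_e5m2
    if fp16:
        return torch.float16
    if fp32:
        return torch.float32
    return torch.float16 if fp16_is_safe else torch.float32

print(pick_text_encoder_dtype())           # torch.float16
print(pick_text_encoder_dtype(fp32=True))  # torch.float32
```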
def vae_device():
return get_torch_device()
@@ -579,27 +598,6 @@ def get_free_memory(dev=None, torch_free_too=False):
else:
return mem_free_total
def batch_area_memory(area):
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: these formulas are copied from maximum_batch_area below
return (area / 20) * (1024 * 1024)
else:
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def maximum_batch_area():
global vram_state
if vram_state == VRAMState.NO_VRAM:
return 0
memory_free = get_free_memory() / (1024 * 1024)
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = 20 * memory_free
else:
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
area = ((memory_free - 1024) * 0.9) / (0.6)
return int(max(area, 0))
def cpu_mode():
global cpu_state
return cpu_state == CPUState.CPU
@@ -37,7 +37,7 @@ class ModelPatcher:
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -52,6 +52,9 @@ class ModelPatcher:
return True
return False
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@@ -93,6 +96,12 @@ class ModelPatcher:
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")
@@ -1,7 +1,7 @@
import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
class EPS:
def calculate_input(self, sigma, noise):
@@ -24,7 +24,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.sigma_data = 1.0
@@ -65,16 +65,65 @@ class ModelSamplingDiscrete(torch.nn.Module):
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep):
t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
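So 0 % now maps to an effectively infinite sigma, 100 % maps to zero, and anything in between is inverted and scaled onto the 0-999 timestep range before the usual log-space lookup. A small self-contained sketch of that mapping, using an invented log-spaced sigma table instead of a real model schedule:

```python
import math
import torch

# Toy stand-in for ModelSamplingDiscrete's sigma table: 1000 log-spaced sigmas.
log_sigmas = torch.logspace(math.log10(0.029), math.log10(14.6), 1000).log()

def percent_to_sigma(percent):
    if percent <= 0.0:
        return 999999999.9
    if percent >= 1.0:
        return 0.0
    t = (1.0 - percent) * 999.0
    low, high = int(t), min(int(t) + 1, 999)
    w = t - low
    return ((1 - w) * log_sigmas[low] + w * log_sigmas[high]).exp().item()

print(percent_to_sigma(0.0))            # 999999999.9 ("apply from the very start")
print(percent_to_sigma(1.0))            # 0.0 ("never applies")
print(round(percent_to_sigma(0.5), 3))  # a sigma near the middle of the toy schedule
```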
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
def set_sigma_range(self, sigma_min, sigma_max):
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
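For the continuous EDM case the same percent mapping is just log-linear interpolation between sigma_min and sigma_max. A quick numeric check, using the SVD sampling_settings that appear later in this diff (sigma_min 0.002, sigma_max 700) rather than the 120.0 default above:

```python
import math

sigma_min, sigma_max = 0.002, 700.0

def percent_to_sigma_edm(percent):
    if percent <= 0.0:
        return 999999999.9
    if percent >= 1.0:
        return 0.0
    p = 1.0 - percent
    log_min = math.log(sigma_min)
    return math.exp((math.log(sigma_max) - log_min) * p + log_min)

print(percent_to_sigma_edm(0.5))  # ~1.18, the geometric mean of 0.002 and 700
```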
@@ -83,7 +83,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, positive, negative, noise_mask, models
@@ -11,7 +11,7 @@ import comfy.conds
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
@@ -134,7 +134,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
@@ -170,9 +170,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
@@ -218,12 +220,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
@@ -242,11 +246,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return out_cond, out_uncond
max_total_area = model_management.maximum_batch_area()
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
@@ -258,7 +261,7 @@ class CFGNoisePredictor(torch.nn.Module):
super().__init__()
self.inner_model = model
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
@@ -511,23 +514,27 @@ class Sampler:
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
self.sampler_function = sampler_function
self.extra_options = extra_options
self.inpaint_options = inpaint_options
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
@@ -543,20 +550,33 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
return samples
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
if sampler_name == "dpm_fast":
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
total_steps = len(sigmas) - 1
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_adaptive_function
else:
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
return KSAMPLER(sampler_function, extra_options, inpaint_options)
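The shape of the refactor: ksampler() no longer returns a class to be instantiated later, it picks a k-diffusion sampling function and returns a KSAMPLER object carrying that function plus its options. A toy mirror of that wiring, with made-up names and a fake sampler function standing in for k_diffusion_sampling:

```python
class MiniKSampler:
    def __init__(self, sampler_function, extra_options=None):
        self.sampler_function = sampler_function
        self.extra_options = extra_options or {}

    def sample(self, model, noise, sigmas):
        return self.sampler_function(model, noise, sigmas, **self.extra_options)

def fake_sample_euler(model, noise, sigmas, s_churn=0.0):
    # Stand-in for a k-diffusion sampler; just shifts the input for demonstration.
    return [n - s_churn for n in noise]

def mini_ksampler(name, extra_options=None):
    table = {"euler": fake_sample_euler}
    return MiniKSampler(table[name], extra_options)

sampler = mini_ksampler("euler", {"s_churn": 0.5})
print(sampler.sample(None, [1.0, 2.0], [14.6, 0.0]))  # [0.5, 1.5]
```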
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
@@ -617,11 +637,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
print("error invalid scheduler", self.scheduler)
return sigmas
def sampler_object(name):
if name == "uni_pc":
sampler = UNIPC()
elif name == "uni_pc_bh2":
sampler = UNIPCBH2()
elif name == "ddim":
sampler = ksampler("euler", inpaint_options={"random": True})
else:
@@ -686,6 +706,6 @@ class KSampler:
else:
return torch.zeros_like(noise)
sampler = sampler_object(self.sampler)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
@@ -23,6 +23,7 @@ import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@@ -95,10 +96,7 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
params['dtype'] = model_management.text_encoder_dtype(load_device)
self.cond_stage_model = clip(**(params))
@@ -157,7 +155,21 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
decoder_config = encoder_config.copy()
decoder_config["video_kernel_size"] = [3, 1, 1]
decoder_config["alpha"] = 0.0
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
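As a rough sanity check of the memory_used_encode / memory_used_decode estimates introduced above (the shapes are an invented example; fp16 assumed, so dtype_size is 2 bytes):

```python
dtype_size = 2                    # bytes per element for fp16
latent_shape = (1, 4, 64, 64)     # one 64x64 latent -> 512x512 output
image_shape = (1, 3, 512, 512)    # the matching input image for encoding

decode_bytes = 2178 * latent_shape[2] * latent_shape[3] * 64 * dtype_size
encode_bytes = 1767 * image_shape[2] * image_shape[3] * dtype_size

print(decode_bytes / (1024 ** 3))  # ~1.06 GiB reserved before decoding
print(encode_bytes / (1024 ** 3))  # ~0.86 GiB reserved before encoding
```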
@@ -175,10 +187,12 @@ class VAE:
if device is None:
device = model_management.vae_device()
self.device = device
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
self.first_stage_model.to(self.vae_dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@@ -207,10 +221,9 @@ class VAE:
return samples
def decode(self, samples_in):
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@@ -223,22 +236,19 @@ class VAE:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
return output.movedim(1,-1)
def encode(self, pixel_samples):
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@@ -251,14 +261,12 @@ class VAE:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
return samples
def get_sd(self):
@@ -444,6 +452,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
if output_clip:
@@ -468,20 +477,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
return None
new_sd = sd
else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None:
return None
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
@@ -501,6 +508,14 @@ def load_unet(unet_path): #load unet in diffusers format
print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd)
if model is None:
print("ERROR UNSUPPORTED UNET", unet_path)
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()])
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
@@ -173,9 +173,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
latent_format = latent_formats.SD15
@@ -69,6 +71,10 @@ class SD20(supported_models_base.BASE):
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
return state_dict
@@ -88,6 +94,7 @@ class SD21UnclipL(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 1536,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@@ -100,6 +107,7 @@ class SD21UnclipH(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 2048,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@@ -112,6 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
"context_dim": 1280,
"adm_in_channels": 2560,
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@@ -148,7 +157,8 @@ class SDXL(supported_models_base.BASE):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@@ -203,8 +213,34 @@ class SSD1B(SDXL):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 4, 4],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 768,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]
@@ -19,7 +19,7 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
sampling_settings = {}
latent_format = latent_formats.LatentFormat
@classmethod
@@ -53,6 +53,12 @@ class BASE:
def process_clip_state_dict(self, state_dict):
return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_vae_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
def scale_latents(x):
@@ -65,3 +66,11 @@ class TAESD(nn.Module):
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
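A hedged usage sketch of the renamed modules and the new helpers, run from a ComfyUI checkout. Without checkpoint paths the weights are random, which is enough to check shapes; real previews need the taesd_encoder.pth / taesd_decoder.pth files:

```python
import torch
from comfy.taesd.taesd import TAESD

taesd = TAESD()                             # random weights; pass encoder_path/decoder_path for real ones
image = torch.rand(1, 3, 512, 512) * 2 - 1  # pixels in [-1, 1]
latent = taesd.encode(image)                # encode() shifts to [0, 1] internally
print(latent.shape)                         # torch.Size([1, 4, 64, 64])
preview = taesd.decode(latent)              # decode() maps back to roughly [-1, 1]
print(preview.shape)                        # torch.Size([1, 3, 512, 512])
```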
@@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
def copy_to_param(obj, attr, value):
@@ -307,23 +307,25 @@ def bislerp(samples, width, height):
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
@@ -336,7 +338,7 @@ def bislerp(samples, width, height):
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
#linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
@@ -347,7 +349,7 @@ def bislerp(samples, width, height):
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result.to(orig_dtype)
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
@@ -16,7 +16,7 @@ class BasicScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@@ -36,7 +36,7 @@ class KarrasScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@@ -54,7 +54,7 @@ class ExponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@@ -73,7 +73,7 @@ class PolyexponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@@ -81,6 +81,25 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
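The timestep selection above always samples from {999, 899, ..., 99}, so SDTurboScheduler with N steps denoises at the N highest of those timesteps. The tensor arithmetic can be checked without a model:

```python
import torch

for steps in range(1, 5):
    timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
    print(steps, timesteps.tolist())
# 1 [999]
# 2 [999, 899]
# 3 [999, 899, 799]
# 4 [999, 899, 799, 699]
```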
class VPScheduler:
@classmethod
def INPUT_TYPES(s):
@@ -92,7 +111,7 @@ class VPScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -109,7 +128,7 @@ class SplitSigmas:
} }
} }
RETURN_TYPES = ("SIGMAS","SIGMAS") RETURN_TYPES = ("SIGMAS","SIGMAS")
CATEGORY = "sampling/custom_sampling" CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas" FUNCTION = "get_sigmas"
@ -118,6 +137,24 @@ class SplitSigmas:
sigmas2 = sigmas[step:] sigmas2 = sigmas[step:]
return (sigmas1, sigmas2) return (sigmas1, sigmas2)
class FlipSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
def get_sigmas(self, sigmas):
sigmas = sigmas.flip(0)
if sigmas[0] == 0:
sigmas[0] = 0.0001
return (sigmas,)
class KSamplerSelect:
    @classmethod
    def INPUT_TYPES(s):
@ -126,12 +163,12 @@ class KSamplerSelect:
                }
            }
    RETURN_TYPES = ("SAMPLER",)
-   CATEGORY = "sampling/custom_sampling"
+   CATEGORY = "sampling/custom_sampling/samplers"

    FUNCTION = "get_sampler"

    def get_sampler(self, sampler_name):
-       sampler = comfy.samplers.sampler_class(sampler_name)()
+       sampler = comfy.samplers.sampler_object(sampler_name)
        return (sampler, )

class SamplerDPMPP_2M_SDE:
@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
                }
            }
    RETURN_TYPES = ("SAMPLER",)
-   CATEGORY = "sampling/custom_sampling"
+   CATEGORY = "sampling/custom_sampling/samplers"

    FUNCTION = "get_sampler"
@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
            sampler_name = "dpmpp_2m_sde"
        else:
            sampler_name = "dpmpp_2m_sde_gpu"
-       sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
+       sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
        return (sampler, )
@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
                }
            }
    RETURN_TYPES = ("SAMPLER",)
-   CATEGORY = "sampling/custom_sampling"
+   CATEGORY = "sampling/custom_sampling/samplers"

    FUNCTION = "get_sampler"
@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
            sampler_name = "dpmpp_sde"
        else:
            sampler_name = "dpmpp_sde_gpu"
-       sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
+       sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
        return (sampler, )

class SamplerCustom:
@ -234,13 +271,15 @@ class SamplerCustom:
NODE_CLASS_MAPPINGS = {
    "SamplerCustom": SamplerCustom,
+   "BasicScheduler": BasicScheduler,
    "KarrasScheduler": KarrasScheduler,
    "ExponentialScheduler": ExponentialScheduler,
    "PolyexponentialScheduler": PolyexponentialScheduler,
    "VPScheduler": VPScheduler,
+   "SDTurboScheduler": SDTurboScheduler,
    "KSamplerSelect": KSamplerSelect,
    "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
    "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
-   "BasicScheduler": BasicScheduler,
    "SplitSigmas": SplitSigmas,
+   "FlipSigmas": FlipSigmas,
}

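For reference, the SDTurboScheduler added above selects sigmas by sampling the model's discrete noise schedule at the highest of the timesteps 99, 199, ..., 999 and appending a trailing zero. A standalone sketch of that index math, with a made-up sigma table standing in for model.model.model_sampling.sigma():

```python
import torch

steps = 4
# Highest `steps` timesteps out of [99, 199, ..., 999], highest first.
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
print(timesteps.tolist())  # [999, 899, 799, 699]

# Stand-in for the model's sigma lookup; a real model maps each discrete
# timestep to the sigma it was trained with.
fake_sigma = lambda t: (t.float() + 1.0) / 100.0
sigmas = fake_sigma(timesteps)

# A final zero sigma is appended so sampling ends on a clean image.
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
print(sigmas)
```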
View File

@ -0,0 +1,175 @@
import nodes
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "crop"
CATEGORY = "image/transform"
def crop(self, image, width, height, x, y):
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
}
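SaveAnimatedWEBP above leans on Pillow's animated WebP writer; the node mostly converts the [0, 1] float tensors to uint8 frames and forwards the UI options. A stripped-down sketch of the same save call outside ComfyUI, with invented frame data:

```python
import numpy as np
from PIL import Image

fps = 6.0
frames = [Image.fromarray(np.full((64, 64, 3), v, dtype=np.uint8)) for v in (0, 128, 255)]

# save_all + append_images turns one save() into an animation; duration is the
# per-frame display time in milliseconds, method trades encode speed for size
# (0 = fastest, 6 = slowest), matching the node's "methods" table.
frames[0].save(
    "demo.webp",
    save_all=True,
    append_images=frames[1:],
    duration=int(1000.0 / fps),
    lossless=True,
    quality=80,
    method=4,
)
```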

View File

@ -1,4 +1,5 @@
import comfy.utils
+import torch

def reshape_latent_to(target_shape, latent):
    if latent.shape[1:] != target_shape[1:]:
@ -67,8 +68,43 @@ class LatentMultiply:
        samples_out["samples"] = s1 * multiplier
        return (samples_out,)

class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2)
m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))
s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)
t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
NODE_CLASS_MAPPINGS = {
    "LatentAdd": LatentAdd,
    "LatentSubtract": LatentSubtract,
    "LatentMultiply": LatentMultiply,
+   "LatentInterpolate": LatentInterpolate,
}

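LatentInterpolate above is a norm-preserving blend: both latents are normalized per position across channels, lerped, renormalized, and then scaled by the interpolated magnitudes. The same arithmetic on plain tensors, with illustrative shapes only:

```python
import torch

def latent_interpolate(s1: torch.Tensor, s2: torch.Tensor, ratio: float) -> torch.Tensor:
    # Per-position magnitude across the channel dimension (dim=1).
    m1 = torch.linalg.vector_norm(s1, dim=(1))
    m2 = torch.linalg.vector_norm(s2, dim=(1))

    # Blend directions, then restore a blended magnitude.
    d1 = torch.nan_to_num(s1 / m1)
    d2 = torch.nan_to_num(s2 / m2)
    t = d1 * ratio + d2 * (1.0 - ratio)
    t = torch.nan_to_num(t / torch.linalg.vector_norm(t, dim=(1)))
    return t * (m1 * ratio + m2 * (1.0 - ratio))

a, b = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
print(latent_interpolate(a, b, 0.5).shape)  # torch.Size([1, 4, 8, 8])
```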
View File

@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
        return c_out * x0 + c_skip * model_input

-class ModelSamplingDiscreteLCM(torch.nn.Module):
+class ModelSamplingDiscreteDistilled(torch.nn.Module):
+   original_timesteps = 50
+
    def __init__(self):
        super().__init__()
        self.sigma_data = 1.0
@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

-       original_timesteps = 50
-       self.skip_steps = timesteps // original_timesteps
+       self.skip_steps = timesteps // self.original_timesteps

-       alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
-       for x in range(original_timesteps):
-           alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+       alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
+       for x in range(self.original_timesteps):
+           alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]

        sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
        self.set_sigmas(sigmas)
@ -55,18 +56,23 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
    def timestep(self, sigma):
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
-       return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
+       return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)

    def sigma(self, timestep):
-       t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
+       t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
-       return log_sigma.exp()
+       return log_sigma.exp().to(timestep.device)

    def percent_to_sigma(self, percent):
-       return self.sigma(torch.tensor(percent * 999.0))
+       if percent <= 0.0:
+           return 999999999.9
+       if percent >= 1.0:
+           return 0.0
+       percent = 1.0 - percent
+       return self.sigma(torch.tensor(percent * 999.0)).item()

def rescale_zero_terminal_snr_sigmas(sigmas):
@ -111,7 +117,7 @@ class ModelSamplingDiscrete:
            sampling_type = comfy.model_sampling.V_PREDICTION
        elif sampling == "lcm":
            sampling_type = LCM
-           sampling_base = ModelSamplingDiscreteLCM
+           sampling_base = ModelSamplingDiscreteDistilled

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass
@ -123,6 +129,36 @@ class ModelSamplingDiscrete:
        m.add_object_patch("model_sampling", model_sampling)
        return (m, )

class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG:
    @classmethod
    def INPUT_TYPES(s):
@ -164,5 +200,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
    "ModelSamplingDiscrete": ModelSamplingDiscrete,
+   "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
    "RescaleCFG": RescaleCFG,
}

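The renamed ModelSamplingDiscreteDistilled keeps only every skip_steps-th entry of the 1000-step DDPM schedule (50 distilled steps for LCM), so continuous sigmas have to be mapped back onto that sparse grid. A small sketch of the index arithmetic, detached from any real model:

```python
timesteps = 1000
original_timesteps = 50                        # distilled steps kept by the class above
skip_steps = timesteps // original_timesteps   # 20

# The x-th kept entry corresponds to this original DDPM timestep.
kept = [t * skip_steps + (skip_steps - 1) for t in range(original_timesteps)]
print(kept[:3], kept[-1])  # [19, 39, 59] 999

# timestep() therefore always lands on one of `kept`, and sigma() inverts
# t = index * skip_steps + (skip_steps - 1) before interpolating log sigmas.
index = (999 - (skip_steps - 1)) / skip_steps
print(index)  # 49.0 -> the last distilled sigma
```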
View File

@ -0,0 +1,53 @@
import torch
import comfy.utils
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp
m = model.clone()
if downscale_after_skip:
m.set_model_input_block_patch_after_skip(input_block_patch)
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
}
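PatchModelAddDownscale shrinks the hidden states at the chosen input block while the current sigma sits inside the window derived from start_percent/end_percent, then upscales them back at the matching output block. A hedged, standalone sketch of that size and gating logic, using F.interpolate as a rough stand-in for comfy.utils.common_upscale and hypothetical sigma values:

```python
import torch
import torch.nn.functional as F

downscale_factor = 2.0
h = torch.randn(1, 320, 96, 96)   # hypothetical UNet hidden states

def maybe_downscale(h, sigma, sigma_start, sigma_end):
    # Only shrink while the noise level is inside the configured window.
    if sigma_end <= sigma <= sigma_start:
        size = (round(h.shape[-2] / downscale_factor), round(h.shape[-1] / downscale_factor))
        return F.interpolate(h, size=size, mode="bicubic", antialias=True)
    return h

print(maybe_downscale(h, sigma=10.0, sigma_start=14.6, sigma_end=4.0).shape)
# torch.Size([1, 320, 48, 48]); the output block patch upscales back to match hsp
```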

View File

@ -0,0 +1,89 @@
import nodes
import torch
import comfy.utils
import comfy.sd
import folder_paths
class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
class SVD_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
if augmentation_level > 0:
encode_pixels += torch.randn_like(pixels) * augmentation_level
t = vae.encode(encode_pixels)
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
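VideoLinearCFGGuidance above ramps the guidance scale linearly across the frame batch, from min_cfg on the first frame up to the sampler's cfg on the last. The same arithmetic on stand-in tensors (frame count and shapes are illustrative only):

```python
import torch

min_cfg, cond_scale, video_frames = 1.0, 2.5, 14
cond = torch.randn(video_frames, 4, 72, 128)    # hypothetical per-frame latents
uncond = torch.randn_like(cond)

# One CFG scale per frame, broadcast over channels and spatial dims.
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0],
                       device=cond.device).reshape((cond.shape[0], 1, 1, 1))
guided = uncond + scale * (cond - uncond)
print(scale.flatten()[:3])   # early frames are guided more gently
```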

View File

@ -681,6 +681,7 @@ def validate_prompt(prompt):
    return (True, None, list(good_outputs), node_errors)

+MAXIMUM_HISTORY_SIZE = 10000

class PromptQueue:
    def __init__(self, server):
@ -699,10 +700,12 @@ class PromptQueue:
            self.server.queue_updated()
            self.not_empty.notify()

-   def get(self):
+   def get(self, timeout=None):
        with self.not_empty:
            while len(self.queue) == 0:
-               self.not_empty.wait()
+               self.not_empty.wait(timeout=timeout)
+               if timeout is not None and len(self.queue) == 0:
+                   return None
            item = heapq.heappop(self.queue)
            i = self.task_counter
            self.currently_running[i] = copy.deepcopy(item)
@ -713,6 +716,8 @@ class PromptQueue:
    def task_done(self, item_id, outputs):
        with self.mutex:
            prompt = self.currently_running.pop(item_id)
+           if len(self.history) > MAXIMUM_HISTORY_SIZE:
+               self.history.pop(next(iter(self.history)))
            self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
            for o in outputs:
                self.history[prompt[1]]["outputs"][o] = outputs[o]
@ -747,10 +752,20 @@ class PromptQueue:
                return True
        return False

-   def get_history(self, prompt_id=None):
+   def get_history(self, prompt_id=None, max_items=None, offset=-1):
        with self.mutex:
            if prompt_id is None:
-               return copy.deepcopy(self.history)
+               out = {}
+               i = 0
+               if offset < 0 and max_items is not None:
+                   offset = len(self.history) - max_items
+               for k in self.history:
+                   if i >= offset:
+                       out[k] = self.history[k]
+                       if max_items is not None and len(out) >= max_items:
+                           break
+                   i += 1
+               return out
            elif prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            else:

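With the get_history change above, asking for the last N entries walks the insertion-ordered history dict and starts copying once it reaches len(history) - N. A toy illustration with plain dicts (the real method runs under the queue mutex):

```python
history = {f"prompt-{i}": {"outputs": {}} for i in range(7)}

def window(history, max_items=None, offset=-1):
    out, i = {}, 0
    if offset < 0 and max_items is not None:
        offset = len(history) - max_items
    for k in history:
        if i >= offset:
            out[k] = history[k]
            if max_items is not None and len(out) >= max_items:
                break
        i += 1
    return out

print(list(window(history, max_items=3)))  # ['prompt-4', 'prompt-5', 'prompt-6']
```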
View File

@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
filename_list_cache = {}

if not os.path.exists(input_directory):
+   try:
        os.makedirs(input_directory)
+   except:
+       print("Failed to create input directory")

def set_output_directory(output_dir):
    global output_directory
@ -228,8 +231,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
    full_output_folder = os.path.join(output_dir, subfolder)

    if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
-       print("Saving image outside the output folder is not allowed.")
-       return {}
+       err = "**** ERROR: Saving image outside the output folder is not allowed." + \
+             "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
+             "\n         output_dir: " + output_dir + \
+             "\n         commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
+       print(err)
+       raise Exception(err)

    try:
        counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1

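The hardened check above refuses filename prefixes that resolve outside the output directory and now raises instead of silently returning. The core of the guard is os.path.commonpath; a standalone illustration with hypothetical paths:

```python
import os

output_dir = "/tmp/comfyui/output"                 # hypothetical output root
candidate = os.path.join(output_dir, "../secrets")  # attempted path escape

full = os.path.abspath(candidate)
if os.path.commonpath((output_dir, full)) != output_dir:
    # The common ancestor is /tmp/comfyui, not the output dir, so refuse.
    raise Exception("Saving image outside the output folder is not allowed: " + full)
```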
View File

@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
        self.taesd = taesd

    def decode_latent_to_preview(self, x0):
-       x_sample = self.taesd.decoder(x0[:1])[0].detach()
-       # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5)  # returns value in [-2, 2]
-       x_sample = x_sample.sub(0.5).mul(2)
+       x_sample = self.taesd.decode(x0[:1])[0].detach()
        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
        x_sample = x_sample.astype(np.uint8)

main.py
View File

@ -88,18 +88,37 @@ def cuda_malloc_warning():
def prompt_worker(q, server):
    e = execution.PromptExecutor(server)
+   last_gc_collect = 0
+   need_gc = False
+   gc_collect_interval = 10.0

    while True:
-       item, item_id = q.get()
+       timeout = None
+       if need_gc:
+           timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
+
+       queue_item = q.get(timeout=timeout)
+       if queue_item is not None:
+           item, item_id = queue_item
            execution_start_time = time.perf_counter()
            prompt_id = item[1]
            e.execute(item[2], prompt_id, item[3], item[4])
+           need_gc = True
            q.task_done(item_id, e.outputs_ui)
            if server.client_id is not None:
                server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
-           print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
+           current_time = time.perf_counter()
+           execution_time = current_time - execution_start_time
+           print("Prompt executed in {:.2f} seconds".format(execution_time))

+       if need_gc:
+           current_time = time.perf_counter()
+           if (current_time - last_gc_collect) > gc_collect_interval:
                gc.collect()
                comfy.model_management.soft_empty_cache()
+               last_gc_collect = current_time
+               need_gc = False

async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

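The worker now waits on the queue with a timeout so it can garbage-collect roughly gc_collect_interval seconds after the last prompt instead of after every prompt. A generic version of the same idle-cleanup pattern using the standard library queue (ComfyUI's PromptQueue returns None on timeout instead of raising queue.Empty):

```python
import gc
import queue
import time

work = queue.Queue()
gc_interval, last_gc, need_gc = 10.0, 0.0, False

while True:
    # Block forever when idle with no pending cleanup, otherwise wake up
    # in time to collect once the quiet period has elapsed.
    timeout = max(gc_interval - (time.perf_counter() - last_gc), 0.0) if need_gc else None
    try:
        item = work.get(timeout=timeout)
    except queue.Empty:
        item = None

    if item is not None:
        ...                                # run the job here
        need_gc = True

    if need_gc and time.perf_counter() - last_gc > gc_interval:
        gc.collect()                        # heavy cleanup only after a quiet period
        last_gc, need_gc = time.perf_counter(), False
```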
View File

@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
        c = []
        for t in conditioning:
            d = t[1].copy()
-           d['start_percent'] = 1.0 - start
-           d['end_percent'] = 1.0 - end
+           d['start_percent'] = start
+           d['end_percent'] = end
            n = [t[0], d]
            c.append(n)
        return (c, )
@ -572,10 +572,69 @@ class LoraLoader:
        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
        return (model_lora, clip_lora)

-class VAELoader:
+class LoraLoaderModelOnly(LoraLoader):
    @classmethod
    def INPUT_TYPES(s):
-       return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
+       return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",) RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae" FUNCTION = "load_vae"
@ -583,6 +642,9 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name) vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd)
@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets: if prev_cnet in cnets:
c_net = cnets[prev_cnet] c_net = cnets[prev_cnet]
else: else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent)) c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
c_net.set_previous_controlnet(prev_cnet) c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net cnets[prev_cnet] = c_net
@ -1275,6 +1337,7 @@ class SaveImage:
self.output_dir = folder_paths.get_output_directory() self.output_dir = folder_paths.get_output_directory()
self.type = "output" self.type = "output"
self.prefix_append = "" self.prefix_append = ""
self.compress_level = 4
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1308,7 +1371,7 @@ class SaveImage:
metadata.add_text(x, json.dumps(extra_pnginfo[x])) metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png" file = f"{filename}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({ results.append({
"filename": file, "filename": file,
"subfolder": subfolder, "subfolder": subfolder,
@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
self.output_dir = folder_paths.get_temp_directory() self.output_dir = folder_paths.get_temp_directory()
self.type = "temp" self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut, "ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange, "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@ -1759,7 +1824,7 @@ def load_custom_nodes():
node_paths = folder_paths.get_folder_paths("custom_nodes") node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = [] node_import_times = []
for custom_node_path in node_paths: for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path) possible_modules = os.listdir(os.path.realpath(custom_node_path))
if "__pycache__" in possible_modules: if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__") possible_modules.remove("__pycache__")
@ -1799,6 +1864,9 @@ def init_custom_nodes():
"nodes_custom_sampler.py", "nodes_custom_sampler.py",
"nodes_hypertile.py", "nodes_hypertile.py",
"nodes_model_advanced.py", "nodes_model_advanced.py",
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
] ]
for node_file in extras_files: for node_file in extras_files:

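The reworked VAELoader above only offers "taesd"/"taesdxl" when both halves of the tiny autoencoder are present in the vae_approx folder, then stitches the two checkpoints into one state dict with taesd_encoder./taesd_decoder. prefixes plus a vae_scale entry. A hedged sketch of the same pairing check on plain filenames (listing invented for illustration):

```python
# Hypothetical directory listing of models/vae_approx/
approx_vaes = ["taesd_decoder.safetensors", "taesd_encoder.safetensors",
               "taesdxl_decoder.safetensors"]

def available_taesd(names):
    have = {prefix: any(n.startswith(prefix + ".") for n in names)
            for prefix in ("taesd_encoder", "taesd_decoder",
                           "taesdxl_encoder", "taesdxl_decoder")}
    out = []
    if have["taesd_encoder"] and have["taesd_decoder"]:
        out.append("taesd")
    if have["taesdxl_encoder"] and have["taesdxl_decoder"]:
        out.append("taesdxl")
    return out

print(available_taesd(approx_vaes))  # ['taesd'] -- the SDXL decoder lacks its encoder
```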
View File

@ -431,7 +431,10 @@ class PromptServer():
@routes.get("/history") @routes.get("/history")
async def get_history(request): async def get_history(request):
return web.json_response(self.prompt_queue.get_history()) max_items = request.rel_url.query.get("max_items", None)
if max_items is not None:
max_items = int(max_items)
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
@routes.get("/history/{prompt_id}") @routes.get("/history/{prompt_id}")
async def get_history(request): async def get_history(request):
@ -573,7 +576,7 @@ class PromptServer():
bytesIO = BytesIO() bytesIO = BytesIO()
header = struct.pack(">I", type_num) header = struct.pack(">I", type_num)
bytesIO.write(header) bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4) image.save(bytesIO, format=image_type, quality=95, compress_level=1)
preview_bytes = bytesIO.getvalue() preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

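Together with the PromptQueue change, the /history route now accepts an optional max_items query parameter. A client-side sketch, assuming the default local server on 127.0.0.1:8188:

```python
import json
from urllib.request import urlopen

# Fetch only the most recent 16 history entries instead of the full log.
with urlopen("http://127.0.0.1:8188/history?max_items=16") as resp:
    history = json.load(resp)

print(len(history), "entries, newest prompt ids:", list(history)[-3:])
```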
View File

@ -20,6 +20,7 @@ async function setup() {
    // Modify the response data to add some checkpoints
    const objectInfo = JSON.parse(data);
    objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
+   objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];

    data = JSON.stringify(objectInfo, undefined, "\t");

View File

@ -0,0 +1,196 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("extensions", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
it("calls each extension hook", async () => {
const mockExtension = {
name: "TestExtension",
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn(),
};
const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.init).toHaveBeenCalledWith(app);
// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
expect(defs).toHaveProperty("KSampler");
expect(defs).toHaveProperty("LoadImage");
// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
// Before register node def will be called once per node type
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < nodeCount; i++) {
// It should be sent the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
expect(nodeClass.name).toBe("ComfyNode");
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
expect(nodeDef.name).toBe(nodeNames[i]);
expect(nodeDef).toHaveProperty("input");
expect(nodeDef).toHaveProperty("output");
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
// it gets sent the graph data that is going to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
// A node created is fired for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// Each node then calls loadedGraphNode to allow them to be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledWith(app);
// Ensure hooks are called in the correct order
const callOrder = [
"init",
"addCustomNodeDefs",
"getCustomWidgets",
"beforeRegisterNodeDef",
"registerCustomNodes",
"beforeConfigureGraph",
"nodeCreated",
"loadedGraphNode",
"afterConfigureGraph",
"setup",
];
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]];
const fn2 = mockExtension[callOrder[i]];
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
}
graph.clear();
// Ensure adding a new node calls the correct callback
ez.LoadImage();
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
// Reload the graph to ensure correct hooks are fired
await graph.reload();
// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
});
it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe("TestNode");
expect(inputName).toBe("test_input");
expect(inputData[0]).toBe("CUSTOMWIDGET");
expect(inputData[1]?.hello).toBe("world");
expect(app).toStrictEqual(app);
return {
widget: node.addWidget("button", inputName, "hello", () => {}),
};
});
// Register our extension that adds a custom node + widget type
const mockExtension = {
name: "TestExtension",
addCustomNodeDefs: (nodeDefs) => {
nodeDefs["TestNode"] = {
output: [],
output_name: [],
output_is_list: [],
name: "TestNode",
display_name: "TestNode",
category: "Test",
input: {
required: {
test_input: ["CUSTOMWIDGET", { hello: "world" }],
},
},
};
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock,
};
}),
};
const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
graph.clear();
expect(widgetMock).toBeCalledTimes(0);
const node = ez.TestNode();
expect(widgetMock).toBeCalledTimes(1);
// Ensure our custom widget is created
expect(node.inputs.length).toBe(0);
expect(node.widgets.length).toBe(1);
const w = node.widgets[0].widget;
expect(w.name).toBe("test_input");
expect(w.type).toBe("button");
});
});

View File

@ -0,0 +1,818 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
/**
*
* @param {*} app
* @param {*} graph
* @param {*} name
* @param {*} nodes
* @returns { Promise<InstanceType<import("../utils/ezgraph")["EzNode"]>> }
*/
async function convertToGroup(app, graph, name, nodes) {
// Select the nodes we are converting
for (const n of nodes) {
n.select(true);
}
expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual(
nodes.map((n) => n.id + "").sort((a, b) => +a - +b)
);
global.prompt = jest.fn().mockImplementation(() => name);
const groupNode = await nodes[0].menu["Convert to Group Node"].call(false);
// Check group name was requested
expect(window.prompt).toHaveBeenCalled();
// Ensure old nodes are removed
for (const n of nodes) {
expect(n.isRemoved).toBeTruthy();
}
expect(groupNode.type).toEqual("workflow/" + name);
return graph.find(groupNode);
}
/**
* @param { Record<string, string | number> | number[] } idMap
* @param { Record<string, Record<string, unknown>> } valueMap
*/
function getOutput(idMap = {}, valueMap = {}) {
if (idMap instanceof Array) {
idMap = idMap.reduce((p, n) => {
p[n] = n + "";
return p;
}, {});
}
const expected = {
1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" },
4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" },
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
...valueMap?.[5],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" },
7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" },
};
// Map old IDs to new at the top level
const mapped = {};
for (const oldId in idMap) {
mapped[idMap[oldId]] = expected[oldId];
delete expected[oldId];
}
Object.assign(mapped, expected);
// Map old IDs to new inside links
for (const k in mapped) {
for (const input in mapped[k].inputs) {
const v = mapped[k].inputs[input];
if (v instanceof Array) {
if (v[0] in idMap) {
v[0] = idMap[v[0]] + "";
}
}
}
}
return mapped;
}
test("can be created from selected nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
// Ensure links are now to the group node
expect(group.inputs).toHaveLength(2);
expect(group.outputs).toHaveLength(3);
expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]);
expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]);
// ckpt clip to both clip inputs on the group
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[group.id, 0],
[group.id, 1],
]);
// group conditioning to sampler
expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
// group conditioning 2 to sampler
expect(
group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
).toEqual([[nodes.sampler.id, 2]]);
// group latent to sampler
expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("maintains all output links on conversion", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const save2 = ez.SaveImage(...nodes.decode.outputs);
const save3 = ez.SaveImage(...nodes.decode.outputs);
// Ensure an output with multiple links maintains them on convert to group
const group = await convertToGroup(app, graph, "test", [nodes.sampler, nodes.decode]);
expect(group.outputs[0].connections.length).toBe(3);
expect(group.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(group.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(group.outputs[0].connections[2].targetNode.id).toBe(save3.id);
// and they're still linked when converting back to nodes
const newNodes = group.menu["Convert to nodes"].call();
const decode = graph.find(newNodes.find((n) => n.type === "VAEDecode"));
expect(decode.outputs[0].connections.length).toBe(3);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id);
});
test("can be be converted back to nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler];
const group = await convertToGroup(app, graph, "test", toConvert);
// Edit some values to ensure they are set back onto the converted nodes
expect(group.widgets["text"].value).toBe("positive");
group.widgets["text"].value = "pos";
expect(group.widgets["CLIPTextEncode text"].value).toBe("negative");
group.widgets["CLIPTextEncode text"].value = "neg";
expect(group.widgets["width"].value).toBe(512);
group.widgets["width"].value = 1024;
expect(group.widgets["sampler_name"].value).toBe("euler");
group.widgets["sampler_name"].value = "ddim";
expect(group.widgets["control_after_generate"].value).toBe("randomize");
group.widgets["control_after_generate"].value = "fixed";
/** @type { Array<any> } */
group.menu["Convert to nodes"].call();
// ensure widget values are set
const pos = graph.find(nodes.pos.id);
expect(pos.node.type).toBe("CLIPTextEncode");
expect(pos.widgets["text"].value).toBe("pos");
const neg = graph.find(nodes.neg.id);
expect(neg.node.type).toBe("CLIPTextEncode");
expect(neg.widgets["text"].value).toBe("neg");
const empty = graph.find(nodes.empty.id);
expect(empty.node.type).toBe("EmptyLatentImage");
expect(empty.widgets["width"].value).toBe(1024);
const sampler = graph.find(nodes.sampler.id);
expect(sampler.node.type).toBe("KSampler");
expect(sampler.widgets["sampler_name"].value).toBe("ddim");
expect(sampler.widgets["control_after_generate"].value).toBe("fixed");
// validate links
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[pos.id, 0],
[neg.id, 0],
]);
expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 2],
]);
expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("it can embed reroutes as inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add and connect a reroute to the clip text encodes
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(nodes.pos.inputs[0]);
reroute.outputs[0].connectTo(nodes.neg.inputs[0]);
// Convert to group and ensure we only have 1 input of the correct type
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]);
expect(group.inputs).toHaveLength(1);
expect(group.inputs[0].input.type).toEqual("CLIP");
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("it can embed reroutes as outputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add a reroute with no output so we output IMAGE even though it's used internally
const reroute = ez.Reroute();
nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]);
// Convert to group and ensure there is an IMAGE output
const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
expect(group.outputs).toHaveLength(1);
expect(group.outputs[0].output.type).toEqual("IMAGE");
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id]));
});
test("it can embed reroutes as pipes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Use reroutes as a pipe
const rerouteModel = ez.Reroute();
const rerouteClip = ez.Reroute();
const rerouteVae = ez.Reroute();
nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]);
nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]);
nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]);
const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
group.outputs[0].connectTo(nodes.sampler.inputs.model);
group.outputs[1].connectTo(nodes.pos.inputs.clip);
group.outputs[1].connectTo(nodes.neg.inputs.clip);
});
test("can handle reroutes used internally", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let reroutes = [];
let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
reroutes.push(reroute);
}
prevNode.outputs[0].connectTo(nodes.sampler.inputs.model);
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("creates with widget values from inner nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt";
nodes.pos.widgets.text.value = "hello";
nodes.neg.widgets.text.value = "world";
nodes.empty.widgets.width.value = 256;
nodes.empty.widgets.height.value = 1024;
nodes.sampler.widgets.seed.value = 1;
nodes.sampler.widgets.control_after_generate.value = "increment";
nodes.sampler.widgets.steps.value = 8;
nodes.sampler.widgets.cfg.value = 4.5;
nodes.sampler.widgets.sampler_name.value = "uni_pc";
nodes.sampler.widgets.scheduler.value = "karras";
nodes.sampler.widgets.denoise.value = 0.9;
const group = await convertToGroup(app, graph, "test", [
nodes.ckpt,
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
]);
expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt");
expect(group.widgets["text"].value).toEqual("hello");
expect(group.widgets["CLIPTextEncode text"].value).toEqual("world");
expect(group.widgets["width"].value).toEqual(256);
expect(group.widgets["height"].value).toEqual(1024);
expect(group.widgets["seed"].value).toEqual(1);
expect(group.widgets["control_after_generate"].value).toEqual("increment");
expect(group.widgets["steps"].value).toEqual(8);
expect(group.widgets["cfg"].value).toEqual(4.5);
expect(group.widgets["sampler_name"].value).toEqual("uni_pc");
expect(group.widgets["scheduler"].value).toEqual("karras");
expect(group.widgets["denoise"].value).toEqual(0.9);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
[nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
[nodes.pos.id]: { text: "hello" },
[nodes.neg.id]: { text: "world" },
[nodes.empty.id]: { width: 256, height: 1024 },
[nodes.sampler.id]: {
seed: 1,
steps: 8,
cfg: 4.5,
sampler_name: "uni_pc",
scheduler: "karras",
denoise: 0.9,
},
})
);
});
test("group inputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[1]);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("group outputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute1 = ez.Reroute();
const reroute2 = ez.Reroute();
group.outputs[0].connectTo(reroute1.inputs[0]);
group.outputs[1].connectTo(reroute2.inputs[0]);
reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("groups can connect to each other", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
group1.outputs[0].connectTo(group2.inputs["positive"]);
group1.outputs[1].connectTo(group2.inputs["negative"]);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
);
});
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
const { api } = require("../../web/scripts/api");
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${nodes.save.id}`,
output: {
images: [
{
filename: "test.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test.png",
type: "output",
},
]);
// Reload
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
await app.loadGraphData(JSON.parse(workflow));
group = graph.find(group);
// Trigger inner nodes to get created
group.node["getInnerNodes"]();
// Check it works for internal node ids
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${group.id}:5`,
output: {
images: [
{
filename: "test2.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test2.png",
type: "output",
},
]);
});
test("allows widgets to be converted to inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
group.widgets[0].convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs["text"]);
primitive.widgets[0].value = "hello";
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id], {
[nodes.pos.id]: { text: "hello" },
})
);
});
test("can be copied", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
group1.widgets["text"].value = "hello";
group1.widgets["width"].value = 256;
group1.widgets["seed"].value = 1;
// Clone the node
group1.menu.Clone.call();
expect(app.graph._nodes).toHaveLength(3);
const group2 = graph.find(app.graph._nodes[2]);
expect(group2.node.type).toEqual("workflow/test");
expect(group2.id).not.toEqual(group1.id);
// Reconnect ckpt
nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]);
group2.widgets["text"].value = "world";
group2.widgets["width"].value = 1024;
group2.widgets["seed"].value = 100;
let i = 0;
expect((await graph.toPrompt()).output).toEqual({
...getOutput([nodes.empty.id, nodes.pos.id, nodes.neg.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], {
[nodes.empty.id]: { width: 256 },
[nodes.pos.id]: { text: "hello" },
[nodes.sampler.id]: { seed: 1 },
}),
...getOutput(
{
[nodes.empty.id]: `${group2.id}:${i++}`,
[nodes.pos.id]: `${group2.id}:${i++}`,
[nodes.neg.id]: `${group2.id}:${i++}`,
[nodes.sampler.id]: `${group2.id}:${i++}`,
[nodes.decode.id]: `${group2.id}:${i++}`,
[nodes.save.id]: `${group2.id}:${i++}`,
},
{
[nodes.empty.id]: { width: 1024 },
[nodes.pos.id]: { text: "world" },
[nodes.sampler.id]: { seed: 100 },
}
),
});
graph.arrange();
});
test("is embedded in workflow", async () => {
let { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
// Clear the environment
({ ez, graph, app } = await start({
resetEnv: true,
}));
// Ensure the node isnt registered
expect(() => ez["workflow/test"]).toThrow();
// Reload the workflow
await app.loadGraphData(JSON.parse(workflow));
// Ensure the node is found
group = graph.find(group);
// Generate prompt and ensure it is as expected
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
})
);
});
test("shows missing node error on missing internal node when loading graph data", async () => {
const { graph } = await start();
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
await graph.app.loadGraphData({
last_node_id: 3,
last_link_id: 1,
nodes: [
{
id: 3,
type: "workflow/testerror",
},
],
links: [],
groups: [],
config: {},
extra: {
groupNodes: {
testerror: {
nodes: [
{
type: "NotKSampler",
},
{
type: "NotVAEDecode",
},
],
},
},
},
});
expect(dialogShow).toBeCalledTimes(1);
const call = dialogShow.mock.calls[0][0].innerHTML;
expect(call).toContain("the following node types were not found");
expect(call).toContain("NotKSampler");
expect(call).toContain("NotVAEDecode");
expect(call).toContain("workflow/testerror");
});
test("maintains widget inputs on conversion back to nodes", async () => {
const { ez, graph, app } = await start();
let pos = ez.CLIPTextEncode({ text: "positive" });
pos.node.title = "Positive";
let neg = ez.CLIPTextEncode({ text: "negative" });
neg.node.title = "Negative";
pos.widgets.text.convertToInput();
neg.widgets.text.convertToInput();
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(pos.inputs.text);
primitive.outputs[0].connectTo(neg.inputs.text);
const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
// This will use a primitive widget named 'value'
expect(group.widgets.length).toBe(1);
expect(group.widgets["value"].value).toBe("positive");
const newNodes = group.menu["Convert to nodes"].call();
pos = graph.find(newNodes.find((n) => n.title === "Positive"));
neg = graph.find(newNodes.find((n) => n.title === "Negative"));
primitive = graph.find(newNodes.find((n) => n.type === "PrimitiveNode"));
expect(pos.inputs).toHaveLength(2);
expect(neg.inputs).toHaveLength(2);
expect(primitive.outputs[0].connections).toHaveLength(2);
expect((await graph.toPrompt()).output).toEqual({
1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
const save = ez.SaveImage();
const empty = ez.EmptyLatentImage();
const decode = ez.VAEDecode();
scale.outputs.LATENT.connectTo(decode.inputs.samples);
decode.outputs.IMAGE.connectTo(save.inputs.images);
empty.outputs.LATENT.connectTo(scale.inputs.samples);
const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]);
const widgets = group.widgets.map((w) => w.widget.name);
expect(widgets).toStrictEqual([
"width",
"height",
"batch_size",
"upscale_method",
"LatentUpscale width",
"LatentUpscale height",
"crop",
"filename_prefix",
]);
});
test("adds output for external links when converting to group", async () => {
const { ez, graph, app } = await start();
const img = ez.EmptyLatentImage();
let decode = ez.VAEDecode(...img.outputs);
const preview1 = ez.PreviewImage(...decode.outputs);
const preview2 = ez.PreviewImage(...decode.outputs);
const group = await convertToGroup(app, graph, "test", [img, decode, preview1]);
// Ensure we have an output connected to the 2nd preview node
expect(group.outputs.length).toBe(1);
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(preview2.id);
// Convert back and ensure both previews are still connected
group.menu["Convert to nodes"].call();
decode = graph.find(decode);
expect(decode.outputs[0].connections.length).toBe(2);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(preview1.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(preview2.id);
});
test("adds output for external links when converting to group when nodes are not in execution order", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
const ckpt = ez.CheckpointLoaderSimple();
const empty = ez.EmptyLatentImage();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const decode1 = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode1.outputs.IMAGE);
ckpt.outputs.MODEL.connectTo(sampler.inputs.model);
pos.outputs.CONDITIONING.connectTo(sampler.inputs.positive);
neg.outputs.CONDITIONING.connectTo(sampler.inputs.negative);
empty.outputs.LATENT.connectTo(sampler.inputs.latent_image);
const encode = ez.VAEEncode(decode1.outputs.IMAGE);
const vae = ez.VAELoader();
const decode2 = ez.VAEDecode(encode.outputs.LATENT, vae.outputs.VAE);
const preview = ez.PreviewImage(decode2.outputs.IMAGE);
vae.outputs.VAE.connectTo(encode.inputs.vae);
const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]);
expect(group.outputs.length).toBe(3);
expect(group.outputs[0].output.name).toBe("VAE");
expect(group.outputs[0].output.type).toBe("VAE");
expect(group.outputs[1].output.name).toBe("IMAGE");
expect(group.outputs[1].output.type).toBe("IMAGE");
expect(group.outputs[2].output.name).toBe("LATENT");
expect(group.outputs[2].output.type).toBe("LATENT");
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[0].connections[0].targetInput.index).toBe(1);
expect(group.outputs[1].connections.length).toBe(1);
expect(group.outputs[1].connections[0].targetNode.id).toBe(save.id);
expect(group.outputs[1].connections[0].targetInput.index).toBe(0);
expect(group.outputs[2].connections.length).toBe(1);
expect(group.outputs[2].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[2].connections[0].targetInput.index).toBe(0);
expect((await graph.toPrompt()).output).toEqual({
...getOutput({ 1: ckpt.id, 2: pos.id, 3: neg.id, 4: empty.id, 5: sampler.id, 6: decode1.id, 7: save.id }),
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: vae.node.type },
[encode.id]: { inputs: { pixels: ["6", 0], vae: [vae.id + "", 0] }, class_type: encode.node.type },
[decode2.id]: { inputs: { samples: [encode.id + "", 0], vae: [vae.id + "", 0] }, class_type: decode2.node.type },
[preview.id]: { inputs: { images: [decode2.id + "", 0] }, class_type: preview.node.type },
});
});
test("works with IMAGEUPLOAD widget", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const preview1 = ez.PreviewImage(img.outputs[0]);
const group = await convertToGroup(app, graph, "test", [img, preview1]);
const widget = group.widgets["upload"];
expect(widget).toBeTruthy();
expect(widget.widget.type).toBe("button");
});
test("internal primitive populates widgets for all linked inputs", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const scale1 = ez.ImageScale(img.outputs[0]);
const scale2 = ez.ImageScale(img.outputs[0]);
ez.PreviewImage(scale1.outputs[0]);
ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.width.convertToInput();
scale2.widgets.height.convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(scale1.inputs.width);
primitive.outputs[0].connectTo(scale2.inputs.height);
const group = await convertToGroup(app, graph, "test", [img, primitive, scale1, scale2]);
group.widgets.value.value = 100;
expect((await graph.toPrompt()).output).toEqual({
1: {
inputs: { image: img.widgets.image.value, upload: "image" },
class_type: "LoadImage",
},
2: {
inputs: { upscale_method: "nearest-exact", width: 100, height: 512, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
3: {
inputs: { upscale_method: "nearest-exact", width: 512, height: 100, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
4: { inputs: { images: ["2", 0] }, class_type: "PreviewImage" },
5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" },
});
});
test("primitive control widgets values are copied on convert", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
sampler.widgets.seed.convertToInput();
sampler.widgets.sampler_name.convertToInput();
let p1 = ez.PrimitiveNode();
let p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(sampler.inputs.seed);
p2.outputs[0].connectTo(sampler.inputs.sampler_name);
p1.widgets.control_after_generate.value = "increment";
p2.widgets.control_after_generate.value = "decrement";
p2.widgets.control_filter_list.value = "/.*/";
p2.node.title = "p2";
const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]);
expect(group.widgets.control_after_generate.value).toBe("increment");
expect(group.widgets["p2 control_after_generate"].value).toBe("decrement");
expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/");
group.widgets.control_after_generate.value = "fixed";
group.widgets["p2 control_after_generate"].value = "randomize";
group.widgets["p2 control_filter_list"].value = "/.+/";
group.menu["Convert to nodes"].call();
p1 = graph.find(p1);
p2 = graph.find(p2);
expect(p1.widgets.control_after_generate.value).toBe("fixed");
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
});

View File

@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
* @param { InstanceType<Ez["EzGraph"]> } graph * @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input * @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType * @param { string } widgetType
* @param { boolean } hasControlWidget * @param { number } controlWidgetCount
* @returns * @returns
*/ */
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) { async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
// Connect to primitive and ensure its still connected after // Connect to primitive and ensure its still connected after
let primitive = ez.PrimitiveNode(); let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input); primitive.outputs[0].connectTo(input);
@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
expect(valueWidget.widget.type).toBe(widgetType); expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added // Check if control_after_generate should be added
if (hasControlWidget) { if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate; const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo"); expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
} }
// Ensure we dont have other widgets // Ensure we dont have other widgets
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget); expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
}); });
return primitive; return primitive;
@ -55,8 +59,8 @@ describe("widget inputs", () => {
}); });
[ [
{ name: "int", type: "INT", widget: "number", control: true }, { name: "int", type: "INT", widget: "number", control: 1 },
{ name: "float", type: "FLOAT", widget: "number", control: true }, { name: "float", type: "FLOAT", widget: "number", control: 1 },
{ name: "text", type: "STRING" }, { name: "text", type: "STRING" },
{ {
name: "customtext", name: "customtext",
@ -64,7 +68,7 @@ describe("widget inputs", () => {
opt: { multiline: true }, opt: { multiline: true },
}, },
{ name: "toggle", type: "BOOLEAN" }, { name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: true }, { name: "combo", type: ["a", "b", "c"], control: 2 },
].forEach((c) => { ].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => { test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({ const { ez, graph } = await start({
@ -106,7 +110,7 @@ describe("widget inputs", () => {
n.widgets.ckpt_name.convertToInput(); n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1); expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true); const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
// Disconnect & reconnect // Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect(); primitive.outputs[0].connections[0].disconnect();
@ -198,8 +202,8 @@ describe("widget inputs", () => {
}); });
expect(dialogShow).toBeCalledTimes(1); expect(dialogShow).toBeCalledTimes(1);
expect(dialogShow.mock.calls[0][0]).toContain("the following node types were not found"); expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
expect(dialogShow.mock.calls[0][0]).toContain("TestNode"); expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
}); });
test("defaultInput widgets can be converted back to inputs", async () => { test("defaultInput widgets can be converted back to inputs", async () => {
@ -226,7 +230,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget // Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return; if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", true); await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n); n = graph.find(n);
expect(n.widgets).toHaveLength(1); expect(n.widgets).toHaveLength(1);
w = n.widgets.example; w = n.widgets.example;
@ -258,7 +262,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget // Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) { if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", true); await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n); n = graph.find(n);
expect(n.widgets).toHaveLength(1); expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy(); expect(n.widgets.example.isConvertedToInput).toBeTruthy();
@ -316,4 +320,76 @@ describe("widget inputs", () => {
n1.outputs[0].connectTo(n2.inputs[0]); n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow(); expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
}); });
test("combo primitive can filter list when control_after_generate called", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
},
});
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
const control = p.widgets.control_after_generate.widget;
const filter = p.widgets.control_filter_list;
expect(p.widgets.length).toBe(3);
control.value = "increment";
expect(value.value).toBe("A");
// Manually trigger after queue when set to increment
control["afterQueued"]();
expect(value.value).toBe("B");
// Filter to items containing D
filter.value = "D";
control["afterQueued"]();
expect(value.value).toBe("D");
control["afterQueued"]();
expect(value.value).toBe("DD");
// Check decrement
value.value = "BBB";
control.value = "decrement";
filter.value = "B";
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("B");
// Check regex works
value.value = "BBB";
filter.value = "/[AB]|^C$/";
control["afterQueued"]();
expect(value.value).toBe("AAA");
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("AA");
control["afterQueued"]();
expect(value.value).toBe("C");
control["afterQueued"]();
expect(value.value).toBe("B");
control["afterQueued"]();
expect(value.value).toBe("A");
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
// Ensure it doesnt apply when fixed
control.value = "fixed";
value.value = "B";
filter.value = "C";
control["afterQueued"]();
expect(value.value).toBe("B");
});
}); });
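
For reference, the filter semantics the new test exercises: a plain `control_filter_list` value is treated as a substring match, while a value wrapped in `/.../` is treated as a regular expression. A minimal sketch of that behaviour, for illustration only (not the widget's actual code):

```js
// Illustrative only: mirrors the behaviour exercised by the test above.
// A filter wrapped in slashes is used as a regex, anything else as a substring match.
function filterComboValues(values, filter) {
	if (!filter) return values;
	const asRegex =
		filter.length > 1 && filter.startsWith("/") && filter.endsWith("/")
			? new RegExp(filter.slice(1, -1))
			: null;
	return values.filter((v) => (asRegex ? asRegex.test(v) : v.includes(filter)));
}

// filterComboValues(["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], "D")
// => ["D", "DD"]
// filterComboValues(["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], "/[AB]|^C$/")
// => ["A", "B", "C", "AA", "BB", "AAA", "BBB"]
```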

View File

@ -150,7 +150,7 @@ export class EzNodeMenuItem {
if (selectNode) { if (selectNode) {
this.node.select(); this.node.select();
} }
this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node); return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
} }
} }
@ -240,8 +240,12 @@ export class EzNode {
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem); return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
} }
select() { get isRemoved() {
this.app.canvas.selectNode(this.node); return !this.app.graph.getNodeById(this.id);
}
select(addToSelection = false) {
this.app.canvas.selectNode(this.node, addToSelection);
} }
// /** // /**
@ -275,12 +279,17 @@ export class EzNode {
if (!s) return p; if (!s) return p;
const name = s[nameProperty]; const name = s[nameProperty];
const item = new ctor(this, i, s);
// @ts-ignore // @ts-ignore
if (!name || name in p) { p.push(item);
if (name) {
// @ts-ignore
if (name in p) {
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`); throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
} }
}
// @ts-ignore // @ts-ignore
p.push((p[name] = new ctor(this, i, s))); p[name] = item;
return p; return p;
}, Object.assign([], { $: this })); }, Object.assign([], { $: this }));
} }
@ -348,6 +357,19 @@ export class EzGraph {
}, 10); }, 10);
}); });
} }
/**
* @returns { Promise<{
* workflow: {},
* output: Record<string, {
* class_name: string,
* inputs: Record<string, [string, number] | unknown>
* }>}> }
*/
toPrompt() {
// @ts-ignore
return this.app.graphToPrompt();
}
} }
export const Ez = { export const Ez = {
@ -356,12 +378,12 @@ export const Ez = {
* @example * @example
* const { ez, graph } = Ez.graph(app); * const { ez, graph } = Ez.graph(app);
* graph.clear(); * graph.clear();
* const [model, clip, vae] = ez.CheckpointLoaderSimple(); * const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }); * const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }); * const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage()); * const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
* const [image] = ez.VAEDecode(latent, vae); * const [image] = ez.VAEDecode(latent, vae).outputs;
* const saveNode = ez.SaveImage(image).node; * const saveNode = ez.SaveImage(image);
* console.log(saveNode); * console.log(saveNode);
* graph.arrange(); * graph.arrange();
* @param { app } app * @param { app } app
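
A short sketch of the new Ez helpers in a test, assuming nodes created with `ez` as above (`app.graph.remove` is LiteGraph's own removal API):

```js
// Sketch only: demonstrates the new isRemoved getter and the addToSelection flag on select().
const sampler = ez.KSampler();
const ckpt = ez.CheckpointLoaderSimple();

sampler.select();      // selects just this node
ckpt.select(true);     // adds to the current selection instead of replacing it

expect(sampler.isRemoved).toBe(false);
app.graph.remove(sampler.node); // LiteGraph removal; isRemoved reflects the graph lookup
expect(sampler.isRemoved).toBe(true);
```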

View File

@ -1,16 +1,24 @@
const { mockApi } = require("./setup"); const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph"); const { Ez } = require("./ezgraph");
const lg = require("./litegraph");
/** /**
* *
* @param { Parameters<mockApi>[0] } config * @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
* @returns * @returns
*/ */
export async function start(config = undefined) { export async function start(config = {}) {
if(config.resetEnv) {
jest.resetModules();
jest.resetAllMocks();
lg.setup(global);
}
mockApi(config); mockApi(config);
const { app } = require("../../web/scripts/app"); const { app } = require("../../web/scripts/app");
config.preSetup?.(app);
await app.setup(); await app.setup();
return Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]); return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
} }
/** /**
@ -37,7 +45,7 @@ export function makeNodeDef(name, input, output = {}) {
output_name: [], output_name: [],
output_is_list: [], output_is_list: [],
input: { input: {
required: {} required: {},
}, },
}; };
for (const k in input) { for (const k in input) {
@ -47,7 +55,7 @@ export function makeNodeDef(name, input, output = {}) {
output = output.reduce((p, c) => { output = output.reduce((p, c) => {
p[c] = c; p[c] = c;
return p; return p;
}, {}) }, {});
} }
for (const k in output) { for (const k in output) {
nodeDef.output.push(output[k]); nodeDef.output.push(output[k]);
@ -69,3 +77,30 @@ export function assertNotNullOrUndefined(x) {
expect(x).not.toEqual(undefined); expect(x).not.toEqual(undefined);
return true; return true;
} }
/**
*
* @param { ReturnType<Ez["graph"]>["ez"] } ez
* @param { ReturnType<Ez["graph"]>["graph"] } graph
*/
export function createDefaultWorkflow(ez, graph) {
graph.clear();
const ckpt = ez.CheckpointLoaderSimple();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const empty = ez.EmptyLatentImage();
const sampler = ez.KSampler(
ckpt.outputs.MODEL,
pos.outputs.CONDITIONING,
neg.outputs.CONDITIONING,
empty.outputs.LATENT
);
const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode.outputs.IMAGE);
graph.arrange();
return { ckpt, pos, neg, empty, sampler, decode, save };
}

View File

@@ -30,16 +30,20 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
 		mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
 	}
 
-	jest.mock("../../web/scripts/api", () => ({
-		get api() {
-			return {
-				addEventListener: jest.fn(),
+	const events = new EventTarget();
+	const mockApi = {
+		addEventListener: events.addEventListener.bind(events),
+		removeEventListener: events.removeEventListener.bind(events),
+		dispatchEvent: events.dispatchEvent.bind(events),
 		getSystemStats: jest.fn(),
 		getExtensions: jest.fn(() => mockExtensions),
 		getNodeDefs: jest.fn(() => mockNodeDefs),
 		init: jest.fn(),
 		apiURL: jest.fn((x) => "../../web/" + x),
 	};
+	jest.mock("../../web/scripts/api", () => ({
+		get api() {
+			return mockApi;
 		},
 	}));
 }

View File

@ -174,6 +174,213 @@ const colorPalettes = {
"tr-odd-bg-color": "#073642", "tr-odd-bg-color": "#073642",
} }
}, },
},
"arc": {
"id": "arc",
"name": "Arc",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAACXBIWXMAAAsTAAALEwEAmpwYAAABcklEQVR4nO3YMUoDARgF4RfxBqZI6/0vZqFn0MYtrLIQMFN8U6V4LAtD+Jm9XG/v30OGl2e/AP7yevz4+vx45nvgF/+QGITEICQGITEIiUFIjNNC3q43u3/YnRJyPOzeQ+0e220nhRzReC8e7R7bbdvl+Jal1Bs46jEIiUFIDEJiEBKDkBhKPbZT6qHdptRTu02p53DUYxASg5AYhMQgJAYhMZR6bKfUQ7tNqad2m1LP4ajHICQGITEIiUFIDEJiKPXYTqmHdptST+02pZ7DUY9BSAxCYhASg5AYhMRQ6rGdUg/tNqWe2m1KPYejHoOQGITEICQGITEIiaHUYzulHtptSj2125R6Dkc9BiExCIlBSAxCYhASQ6nHdko9tNuUemq3KfUcjnoMQmIQEoOQGITEICSGUo/tlHpotyn11G5T6jkc9RiExCAkBiExCIlBSAylHtsp9dBuU+qp3abUczjqMQiJQUgMQmIQEoOQGITE+AHFISNQrFTGuwAAAABJRU5ErkJggg==",
"CLEAR_BACKGROUND_COLOR": "#2b2f38",
"NODE_TITLE_COLOR": "#b2b7bd",
"NODE_SELECTED_TITLE_COLOR": "#FFF",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#AAA",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2b2f38",
"NODE_DEFAULT_BGCOLOR": "#242730",
"NODE_DEFAULT_BOXCOLOR": "#6e7581",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#FFF",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 22,
"WIDGET_BGCOLOR": "#2b2f38",
"WIDGET_OUTLINE_COLOR": "#6e7581",
"WIDGET_TEXT_COLOR": "#DDD",
"WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#fff",
"bg-color": "#2b2f38",
"comfy-menu-bg": "#242730",
"comfy-input-bg": "#2b2f38",
"input-text": "#ddd",
"descrip-text": "#b2b7bd",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#6e7581",
"tr-even-bg-color": "#2b2f38",
"tr-odd-bg-color": "#242730"
}
},
},
"nord": {
"id": "nord",
"name": "Nord",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFu2lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOjUwNDFhMmZjLTEzNzQtMTk0ZC1hZWY4LTYxMzM1MTVmNjUwMCIgeG1wTU06RG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiIHhtcE1NOk9yaWdpbmFsRG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiPiA8eG1wTU06SGlzdG9yeT4gPHJkZjpTZXE+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJjcmVhdGVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDo1MDQxYTJmYy0xMzc0LTE5NGQtYWVmOC02MTMzNTE1ZjY1MDAiIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz73jWg/AAAAyUlEQVR42u3WKwoAIBRFQRdiMb1idv9Lsxn9gEFw4Dbb8JCTojbbXEJwjJVL2HKwYMGCBQuWLbDmjr+9zrBGjHl1WVcvy2DBggULFizTWQpewSt4HzwsgwULFiwFr7MUvMtS8D54WLBgGSxYCl7BK3iXZbBgwYIFC5bpLAWv4BW8Dx6WwYIFC5aC11kK3mUpeB88LFiwDBYsBa/gFbzLMliwYMGCBct0loJX8AreBw/LYMGCBUvB6ywF77IUvA8eFixYBgsWrNfWAZPltufdad+1AAAAAElFTkSuQmCC",
"CLEAR_BACKGROUND_COLOR": "#212732",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2e3440",
"NODE_DEFAULT_BGCOLOR": "#161b22",
"NODE_DEFAULT_BOXCOLOR": "#545d70",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#2e3440",
"WIDGET_OUTLINE_COLOR": "#545d70",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#2e3440",
"comfy-menu-bg": "#161b22",
"comfy-input-bg": "#2e3440",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#545d70",
"tr-even-bg-color": "#2e3440",
"tr-odd-bg-color": "#161b22"
}
},
},
"github": {
"id": "github",
"name": "Github",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGlmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOmIyYzRhNjA5LWJmYTctYTg0MC1iOGFlLTk3MzE2ZjM1ZGIyNyIgeG1wTU06RG9jdW1lbnRJRD0iYWRvYmU6ZG9jaWQ6cGhvdG9zaG9wOjk0ZmNlZGU4LTE1MTctZmQ0MC04ZGU3LWYzOTgxM2E3ODk5ZiIgeG1wTU06T3JpZ2luYWxEb2N1bWVudElEPSJ4bXAuZGlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCI+IDx4bXBNTTpIaXN0b3J5PiA8cmRmOlNlcT4gPHJkZjpsaSBzdEV2dDphY3Rpb249ImNyZWF0ZWQiIHN0RXZ0Omluc3RhbmNlSUQ9InhtcC5paWQ6MjMxYjEwYjAtYjRmYi0wMjRlLWIxMmUtMzA1MzAzY2QwN2M4IiBzdEV2dDp3aGVuPSIyMDIzLTExLTEzVDAwOjE4OjAyKzAxOjAwIiBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZG9iZSBQaG90b3Nob3AgMjUuMSAoV2luZG93cykiLz4gPHJkZjpsaSBzdEV2dDphY3Rpb249InNhdmVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjQ4OWY1NzlmLTJkNjUtZWQ0Zi04OTg0LTA4NGE2MGE1ZTMzNSIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xNVQwMjowNDo1OSswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiBzdEV2dDpjaGFuZ2VkPSIvIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDpiMmM0YTYwOS1iZmE3LWE4NDAtYjhhZS05NzMxNmYzNWRiMjciIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz4OTe6GAAAAx0lEQVR42u3WMQoAIQxFwRzJys77X8vSLiRgITif7bYbgrwYc/mKXyBoY4VVBgsWLFiwYFmOlTv+9jfDOjHmr8u6eVkGCxYsWLBgmc5S8ApewXvgYRksWLBgKXidpeBdloL3wMOCBctgwVLwCl7BuyyDBQsWLFiwTGcpeAWv4D3wsAwWLFiwFLzOUvAuS8F74GHBgmWwYCl4Ba/gXZbBggULFixYprMUvIJX8B54WAYLFixYCl5nKXiXpeA98LBgwTJYsGC9tg1o8f4TTtqzNQAAAABJRU5ErkJggg==",
"CLEAR_BACKGROUND_COLOR": "#040506",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#161b22",
"NODE_DEFAULT_BGCOLOR": "#13171d",
"NODE_DEFAULT_BOXCOLOR": "#30363d",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#161b22",
"WIDGET_OUTLINE_COLOR": "#30363d",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#161b22",
"comfy-menu-bg": "#13171d",
"comfy-input-bg": "#161b22",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#30363d",
"tr-even-bg-color": "#161b22",
"tr-odd-bg-color": "#13171d"
}
},
} }
}; };
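
Custom palettes loaded through the Color Palette setting follow the same three-part shape as the entries added above; a minimal sketch with placeholder values (only the overall structure is meant to be accurate here):

```js
// Placeholder values; node_slot colors the link/slot types, litegraph_base overrides
// LiteGraph canvas constants, and comfy_base feeds the UI's CSS variables.
const myTheme = {
	id: "my_theme",
	name: "My Theme",
	colors: {
		node_slot: { CLIP: "#eacb8b", MODEL: "#8978a7" },
		litegraph_base: { CLEAR_BACKGROUND_COLOR: "#222222", NODE_DEFAULT_BGCOLOR: "#333333" },
		comfy_base: { "fg-color": "#ffffff", "bg-color": "#202020" },
	},
};
```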

File diff suppressed because it is too large

View File

@ -1,5 +1,6 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { ComfyDialog, $el } from "../../scripts/ui.js"; import { ComfyDialog, $el } from "../../scripts/ui.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Adds the ability to save and add multiple nodes as a template // Adds the ability to save and add multiple nodes as a template
// To save: // To save:
@ -291,11 +292,11 @@ app.registerExtension({
setup() { setup() {
const manage = new ManageTemplates(); const manage = new ManageTemplates();
const clipboardAction = (cb) => { const clipboardAction = async (cb) => {
// We use the clipboard functions but dont want to overwrite the current user clipboard // We use the clipboard functions but dont want to overwrite the current user clipboard
// Restore it after we've run our callback // Restore it after we've run our callback
const old = localStorage.getItem("litegrapheditor_clipboard"); const old = localStorage.getItem("litegrapheditor_clipboard");
cb(); await cb();
localStorage.setItem("litegrapheditor_clipboard", old); localStorage.setItem("litegrapheditor_clipboard", old);
}; };
@ -309,13 +310,31 @@ app.registerExtension({
disabled: !Object.keys(app.canvas.selected_nodes || {}).length, disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => { callback: () => {
const name = prompt("Enter name"); const name = prompt("Enter name");
if (!name || !name.trim()) return; if (!name?.trim()) return;
clipboardAction(() => { clipboardAction(() => {
app.canvas.copyToClipboard(); app.canvas.copyToClipboard();
let data = localStorage.getItem("litegrapheditor_clipboard");
data = JSON.parse(data);
const nodeIds = Object.keys(app.canvas.selected_nodes);
for (let i = 0; i < nodeIds.length; i++) {
const node = app.graph.getNodeById(nodeIds[i]);
const nodeData = node?.constructor.nodeData;
let groupData = GroupNodeHandler.getGroupData(node);
if (groupData) {
groupData = groupData.nodeData;
if (!data.groupNodes) {
data.groupNodes = {};
}
data.groupNodes[nodeData.name] = groupData;
data.nodes[i].type = nodeData.name;
}
}
manage.templates.push({ manage.templates.push({
name, name,
data: localStorage.getItem("litegrapheditor_clipboard"), data: JSON.stringify(data),
}); });
manage.store(); manage.store();
}); });
@ -323,15 +342,19 @@ app.registerExtension({
}); });
// Map each template to a menu item // Map each template to a menu item
const subItems = manage.templates.map((t) => ({ const subItems = manage.templates.map((t) => {
return {
content: t.name, content: t.name,
callback: () => { callback: () => {
clipboardAction(() => { clipboardAction(async () => {
const data = JSON.parse(t.data);
await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {});
localStorage.setItem("litegrapheditor_clipboard", t.data); localStorage.setItem("litegrapheditor_clipboard", t.data);
app.canvas.pasteFromClipboard(); app.canvas.pasteFromClipboard();
}); });
}, },
})); };
});
subItems.push(null, { subItems.push(null, {
content: "Manage", content: "Manage",

View File

@ -0,0 +1,150 @@
import { app } from "../../scripts/app.js";
const MAX_HISTORY = 50;
let undo = [];
let redo = [];
let activeState = null;
let isOurLoad = false;
function checkState() {
const currentState = app.graph.serialize();
if (!graphEqual(activeState, currentState)) {
undo.push(activeState);
if (undo.length > MAX_HISTORY) {
undo.shift();
}
activeState = clone(currentState);
redo.length = 0;
}
}
const loadGraphData = app.loadGraphData;
app.loadGraphData = async function () {
const v = await loadGraphData.apply(this, arguments);
if (isOurLoad) {
isOurLoad = false;
} else {
checkState();
}
return v;
};
function clone(obj) {
try {
if (typeof structuredClone !== "undefined") {
return structuredClone(obj);
}
} catch (error) {
// structuredClone is stricter than using JSON.parse/stringify so fallback to that
}
return JSON.parse(JSON.stringify(obj));
}
function graphEqual(a, b, root = true) {
if (a === b) return true;
if (typeof a == "object" && a && typeof b == "object" && b) {
const keys = Object.getOwnPropertyNames(a);
if (keys.length != Object.getOwnPropertyNames(b).length) {
return false;
}
for (const key of keys) {
let av = a[key];
let bv = b[key];
if (root && key === "nodes") {
// Nodes need to be sorted as the order changes when selecting nodes
av = [...av].sort((a, b) => a.id - b.id);
bv = [...bv].sort((a, b) => a.id - b.id);
}
if (!graphEqual(av, bv, false)) {
return false;
}
}
return true;
}
return false;
}
const undoRedo = async (e) => {
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
const prevState = redo.pop();
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
} else if (e.key === "z") {
const prevState = undo.pop();
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
}
}
};
const bindInput = (activeEl) => {
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
for (const evt of ["change", "input", "blur"]) {
if (`on${evt}` in activeEl) {
const listener = () => {
checkState();
activeEl.removeEventListener(evt, listener);
};
activeEl.addEventListener(evt, listener);
return true;
}
}
}
};
window.addEventListener(
"keydown",
(e) => {
requestAnimationFrame(async () => {
const activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
// Check if this is a ctrl+z ctrl+y
if (await undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done
if (bindInput(activeEl)) return;
checkState();
});
},
true
);
// Handle clicking DOM elements (e.g. widgets)
window.addEventListener("mouseup", () => {
checkState();
});
// Handle litegraph clicks
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
LGraphCanvas.prototype.processMouseUp = function (e) {
const v = processMouseUp.apply(this, arguments);
checkState();
return v;
};
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
LGraphCanvas.prototype.processMouseDown = function (e) {
const v = processMouseDown.apply(this, arguments);
checkState();
return v;
};
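
The core of the extension above is a simple three-part history: an undo stack, a redo stack, and the last committed state. A trimmed-down sketch of that bookkeeping (the real code compares states with graphEqual and guards its own loads with isOurLoad):

```js
// Simplified model of the history handling above; serialize() stands in for app.graph.serialize().
const MAX_HISTORY = 50;
const undoStack = [];
const redoStack = [];
let activeState = serialize();

function checkState() {
	const current = serialize();
	if (JSON.stringify(current) === JSON.stringify(activeState)) return; // real code: graphEqual()
	undoStack.push(activeState);
	if (undoStack.length > MAX_HISTORY) undoStack.shift();
	activeState = current;
	redoStack.length = 0; // any new edit invalidates redo
}

async function undoOnce() {
	const prev = undoStack.pop();
	if (!prev) return;
	redoStack.push(activeState);
	await app.loadGraphData(prev); // flagged with isOurLoad above so it is not re-recorded
	activeState = prev;
}
```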

View File

@ -1,4 +1,4 @@
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js"; import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget"; const CONVERTED_TYPE = "converted-widget";
@ -121,6 +121,110 @@ function isValidCombo(combo, obj) {
return true; return true;
} }
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
}
if (config1[0] instanceof Array) {
if (!isValidCombo(config1[0], config2[0])) return false;
} else if (config1[0] !== config2[0]) {
// Types dont match
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return false;
}
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
let customConfig;
const getCustomConfig = () => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
};
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
if (v1 === v2 || (!v1 && !v2)) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]?.["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]?.["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
// No current step
step = v2;
} else if (v2 == null) {
// No new step
step = v1;
} else {
if (v1 < v2) {
// Ensure v1 is larger for the mod
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
return false;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return false;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
const widget = recreateWidget?.call(this);
// When deleting a node this can be null
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return { customConfig };
}
app.registerExtension({ app.registerExtension({
name: "Comfy.WidgetInputs", name: "Comfy.WidgetInputs",
async beforeRegisterNodeDef(nodeType, nodeData, app) { async beforeRegisterNodeDef(nodeType, nodeData, app) {
@ -308,7 +412,7 @@ app.registerExtension({
this.isVirtualNode = true; this.isVirtualNode = true;
} }
applyToGraph() { applyToGraph(extraLinks = []) {
if (!this.outputs[0].links?.length) return; if (!this.outputs[0].links?.length) return;
function get_links(node) { function get_links(node) {
@ -325,10 +429,9 @@ app.registerExtension({
return links; return links;
} }
let links = get_links(this); let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
// For each output link copy our value over the original widget value // For each output link copy our value over the original widget value
for (const l of links) { for (const linkInfo of links) {
const linkInfo = app.graph.links[l];
const node = this.graph.getNodeById(linkInfo.target_id); const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot]; const input = node.inputs[linkInfo.target_slot];
const widgetName = input.widget.name; const widgetName = input.widget.name;
@ -405,7 +508,12 @@ app.registerExtension({
} }
if (this.outputs[slot].links?.length) { if (this.outputs[slot].links?.length) {
return this.#isValidConnection(input); const valid = this.#isValidConnection(input);
if (valid) {
// On connect of additional outputs, copy our value to their widget
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
}
return valid;
} }
} }
@ -462,12 +570,16 @@ app.registerExtension({
} }
} }
if (widget.type === "number" || widget.type === "combo") { if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
let control_value = this.widgets_values?.[1]; let control_value = this.widgets_values?.[1];
if (!control_value) { if (!control_value) {
control_value = "fixed"; control_value = "fixed";
} }
addValueControlWidget(this, widget, control_value); addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
} }
// When our value changes, update other widgets to reflect our changes // When our value changes, update other widgets to reflect our changes
@ -503,6 +615,7 @@ app.registerExtension({
this.#removeWidgets(); this.#removeWidgets();
this.#onFirstConnection(true); this.#onFirstConnection(true);
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
return this.widgets[0];
} }
#mergeWidgetConfig() { #mergeWidgetConfig() {
@ -543,108 +656,8 @@ app.registerExtension({
#isValidConnection(input, forceUpdate) { #isValidConnection(input, forceUpdate) {
// Only allow connections where the configs match // Only allow connections where the configs match
const output = this.outputs[0]; const output = this.outputs[0];
const config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
const config2 = input.widget[GET_CONFIG](); const config2 = input.widget[GET_CONFIG]();
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
if (config1[0] instanceof Array) {
if (!isValidCombo(config1[0], config2[0])) return false;
} else if (config1[0] !== config2[0]) {
// Types dont match
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return false;
}
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
let customConfig;
const getCustomConfig = () => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
};
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
let v1 = config1[1][k];
let v2 = config2[1][k];
if (v1 === v2 || (!v1 && !v2)) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
// No current step
step = v2;
} else if (v2 == null) {
// No new step
step = v1;
} else {
if (v1 < v2) {
// Ensure v1 is larger for the mod
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
return false;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return false;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
this.#recreateWidget();
const widget = this.widgets[0];
// When deleting a node this can be null
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return true;
} }
#removeWidgets() { #removeWidgets() {

View File

@@ -2533,7 +2533,7 @@
             var w = this.widgets[i];
             if(!w)
                 continue;
-            if(w.options && w.options.property && this.properties[ w.options.property ])
+            if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined))
                 w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) );
         }
         if (info.widgets_values) {
@@ -5714,10 +5714,10 @@ LGraphNode.prototype.executeAction = function(action)
      * @method enableWebGL
      **/
     LGraphCanvas.prototype.enableWebGL = function() {
-        if (typeof GL === undefined) {
+        if (typeof GL === "undefined") {
             throw "litegl.js must be included to use a WebGL canvas";
         }
-        if (typeof enableWebGLCanvas === undefined) {
+        if (typeof enableWebGLCanvas === "undefined") {
             throw "webglCanvas.js must be included to use this feature";
         }
 
@@ -7110,15 +7110,16 @@ LGraphNode.prototype.executeAction = function(action)
         }
     };
 
-    LGraphCanvas.prototype.copyToClipboard = function() {
+    LGraphCanvas.prototype.copyToClipboard = function(nodes) {
         var clipboard_info = {
             nodes: [],
             links: []
         };
         var index = 0;
         var selected_nodes_array = [];
-        for (var i in this.selected_nodes) {
-            var node = this.selected_nodes[i];
+        if (!nodes) nodes = this.selected_nodes;
+        for (var i in nodes) {
+            var node = nodes[i];
             if (node.clonable === false)
                 continue;
             node._relative_id = index;
@@ -11702,7 +11703,7 @@ LGraphNode.prototype.executeAction = function(action)
                 default:
                     iS = 0; // try with first if no name set
             }
-            if (typeof options.node_from.outputs[iS] !== undefined){
+            if (typeof options.node_from.outputs[iS] !== "undefined"){
                 if (iS!==false && iS>-1){
                     options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type );
                 }
@@ -11730,7 +11731,7 @@ LGraphNode.prototype.executeAction = function(action)
                 default:
                     iS = 0; // try with first if no name set
             }
-            if (typeof options.node_to.inputs[iS] !== undefined){
+            if (typeof options.node_to.inputs[iS] !== "undefined"){
                 if (iS!==false && iS>-1){
                     // try connection
                     options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type);

View File

@@ -254,9 +254,9 @@ class ComfyApi extends EventTarget {
 	 * Gets the prompt execution history
 	 * @returns Prompt history including node outputs
 	 */
-	async getHistory() {
+	async getHistory(max_items=200) {
 		try {
-			const res = await this.fetchApi("/history");
+			const res = await this.fetchApi(`/history?max_items=${max_items}`);
 			return { History: Object.values(await res.json()) };
 		} catch (error) {
 			console.error(error);
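
A sketch of calling the updated method from frontend code (the default of 200 comes from the signature above; the import path is the one core extensions use):

```js
import { api } from "../../scripts/api.js";

// Only fetch the 50 most recent history entries instead of the default 200.
const { History } = await api.getHistory(50);
for (const item of History) {
	// each entry contains the queued prompt and its node outputs
}
```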

View File

@ -4,7 +4,10 @@ import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js"; import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js"; import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
import { addDomClippingSetting } from "./domWidget.js";
import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js"
export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"
function sanitizeNodeName(string) { function sanitizeNodeName(string) {
let entityMap = { let entityMap = {
@ -409,7 +412,9 @@ export class ComfyApp {
return shiftY; return shiftY;
} }
node.prototype.setSizeForImage = function () { node.prototype.setSizeForImage = function (force) {
if(!force && this.animatedImages) return;
if (this.inputHeight) { if (this.inputHeight) {
this.setSize(this.size); this.setSize(this.size);
return; return;
@ -426,13 +431,20 @@ export class ComfyApp {
let imagesChanged = false let imagesChanged = false
const output = app.nodeOutputs[this.id + ""]; const output = app.nodeOutputs[this.id + ""];
if (output && output.images) { if (output?.images) {
this.animatedImages = output?.animated?.find(Boolean);
if (this.images !== output.images) { if (this.images !== output.images) {
this.images = output.images; this.images = output.images;
imagesChanged = true; imagesChanged = true;
imgURLs = imgURLs.concat(output.images.map(params => { imgURLs = imgURLs.concat(
return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam() + app.getRandParam()); output.images.map((params) => {
})) return api.apiURL(
"/view?" +
new URLSearchParams(params).toString() +
(this.animatedImages ? "" : app.getPreviewFormatParam()) + app.getRandParam()
);
})
);
} }
} }
@@ -511,7 +523,35 @@ export class ComfyApp {
return true; return true;
} }
if (this.imgs && this.imgs.length) { if (this.imgs?.length) {
const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET);
if(this.animatedImages) {
// Instead of using the canvas we'll use an IMG
if(widgetIdx > -1) {
// Replace content
const widget = this.widgets[widgetIdx];
widget.options.host.updateImages(this.imgs);
} else {
const host = createImageHost(this);
this.setSizeForImage(true);
const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, {
host,
getHeight: host.getHeight,
onDraw: host.onDraw,
hideOnZoom: false
});
widget.serializeValue = () => undefined;
widget.options.host.updateImages(this.imgs);
}
return;
}
if (widgetIdx > -1) {
this.widgets[widgetIdx].onRemove?.();
this.widgets.splice(widgetIdx, 1);
}
const canvas = app.graph.list_of_graphcanvas[0]; const canvas = app.graph.list_of_graphcanvas[0];
const mouse = canvas.graph_mouse; const mouse = canvas.graph_mouse;
if (!canvas.pointer_is_down && this.pointerDown) { if (!canvas.pointer_is_down && this.pointerDown) {
@@ -551,31 +591,7 @@ export class ComfyApp {
} }
else { else {
cell_padding = 0; cell_padding = 0;
let best = 0; ({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh));
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
// compact style
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
}
}
} }
let anyHovered = false; let anyHovered = false;
@@ -767,7 +783,7 @@ export class ComfyApp {
* Adds a handler on paste that extracts and loads images or workflows from pasted JSON data * Adds a handler on paste that extracts and loads images or workflows from pasted JSON data
*/ */
#addPasteHandler() { #addPasteHandler() {
document.addEventListener("paste", (e) => { document.addEventListener("paste", async (e) => {
// ctrl+shift+v is used to paste nodes with connections // ctrl+shift+v is used to paste nodes with connections
// this is handled by litegraph // this is handled by litegraph
if(this.shiftDown) return; if(this.shiftDown) return;
@@ -815,7 +831,7 @@ export class ComfyApp {
} }
if (workflow && workflow.version && workflow.nodes && workflow.extra) { if (workflow && workflow.version && workflow.nodes && workflow.extra) {
this.loadGraphData(workflow); await this.loadGraphData(workflow);
} }
else { else {
if (e.target.type === "text" || e.target.type === "textarea") { if (e.target.type === "text" || e.target.type === "textarea") {
@@ -1165,7 +1181,19 @@ export class ComfyApp {
}); });
api.addEventListener("executed", ({ detail }) => { api.addEventListener("executed", ({ detail }) => {
const output = this.nodeOutputs[detail.node];
if (detail.merge && output) {
for (const k in detail.output ?? {}) {
const v = output[k];
if (v instanceof Array) {
output[k] = v.concat(detail.output[k]);
} else {
output[k] = detail.output[k];
}
}
} else {
this.nodeOutputs[detail.node] = detail.output; this.nodeOutputs[detail.node] = detail.output;
}
const node = this.graph.getNodeById(detail.node); const node = this.graph.getNodeById(detail.node);
if (node) { if (node) {
if (node.onExecuted) if (node.onExecuted)
@@ -1276,9 +1304,11 @@ export class ComfyApp {
canvasEl.tabIndex = "1"; canvasEl.tabIndex = "1";
document.body.prepend(canvasEl); document.body.prepend(canvasEl);
addDomClippingSetting();
this.#addProcessMouseHandler(); this.#addProcessMouseHandler();
this.#addProcessKeyHandler(); this.#addProcessKeyHandler();
this.#addConfigureHandler(); this.#addConfigureHandler();
this.#addApiUpdateHandlers();
this.graph = new LGraph(); this.graph = new LGraph();
@@ -1315,7 +1345,7 @@ export class ComfyApp {
const json = localStorage.getItem("workflow"); const json = localStorage.getItem("workflow");
if (json) { if (json) {
const workflow = JSON.parse(json); const workflow = JSON.parse(json);
this.loadGraphData(workflow); await this.loadGraphData(workflow);
restored = true; restored = true;
} }
} catch (err) { } catch (err) {
@@ -1324,7 +1354,7 @@ export class ComfyApp {
// We failed to restore a workflow so load the default // We failed to restore a workflow so load the default
if (!restored) { if (!restored) {
this.loadGraphData(); await this.loadGraphData();
} }
// Save current workflow automatically // Save current workflow automatically
@@ -1332,7 +1362,6 @@ export class ComfyApp {
this.#addDrawNodeHandler(); this.#addDrawNodeHandler();
this.#addDrawGroupsHandler(); this.#addDrawGroupsHandler();
this.#addApiUpdateHandlers();
this.#addDropHandler(); this.#addDropHandler();
this.#addCopyHandler(); this.#addCopyHandler();
this.#addPasteHandler(); this.#addPasteHandler();
@@ -1352,24 +1381,27 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("registerCustomNodes"); await this.#invokeExtensionsAsync("registerCustomNodes");
} }
async registerNodesFromDefs(defs) { getWidgetType(inputData, inputName) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); const type = inputData[0];
// Generate list of known widgets if (Array.isArray(type)) {
const widgets = Object.assign( return "COMBO";
{}, } else if (`${type}:${inputName}` in this.widgets) {
ComfyWidgets, return `${type}:${inputName}`;
...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean) } else if (type in this.widgets) {
); return type;
} else {
return null;
}
}
// Register a node for each definition async registerNodeDef(nodeId, nodeData) {
for (const nodeId in defs) { const self = this;
const nodeData = defs[nodeId];
const node = Object.assign( const node = Object.assign(
function ComfyNode() { function ComfyNode() {
var inputs = nodeData["input"]["required"]; var inputs = nodeData["input"]["required"];
if (nodeData["input"]["optional"] != undefined) { if (nodeData["input"]["optional"] != undefined) {
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]) inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]);
} }
const config = { minWidth: 1, minHeight: 1 }; const config = { minWidth: 1, minHeight: 1 };
for (const inputName in inputs) { for (const inputName in inputs) {
@@ -1377,15 +1409,13 @@ export class ComfyApp {
const type = inputData[0]; const type = inputData[0];
let widgetCreated = true; let widgetCreated = true;
if (Array.isArray(type)) { const widgetType = self.getWidgetType(inputData, inputName);
// Enums if(widgetType) {
Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {}); if(widgetType === "COMBO") {
} else if (`${type}:${inputName}` in widgets) { Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {});
// Support custom widgets by Type:Name } else {
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {}); Object.assign(config, self.widgets[widgetType](this, inputName, inputData, app) || {});
} else if (type in widgets) { }
// Standard type widgets
Object.assign(config, widgets[type](this, inputName, inputData, app) || {});
} else { } else {
// Node connection inputs // Node connection inputs
this.addInput(inputName, type); this.addInput(inputName, type);
@@ -1434,6 +1464,21 @@ export class ComfyApp {
LiteGraph.registerNodeType(nodeId, node); LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category; node.category = nodeData.category;
} }
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
// Generate list of known widgets
this.widgets = Object.assign(
{},
ComfyWidgets,
...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean)
);
// Register a node for each definition
for (const nodeId in defs) {
this.registerNodeDef(nodeId, defs[nodeId]);
}
} }
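getWidgetType centralises the widget-resolution order that the old inline code performed: array input data means an enum (COMBO), then a Type:Name specific widget, then the plain Type widget, and finally null for connection-only inputs. A rough illustration of that order; the input data and names below are hypothetical:

```js
app.getWidgetType([["euler", "ddim"]], "sampler_name"); // "COMBO" – array input data is an enum
app.getWidgetType(["INT", { default: 0 }], "seed");     // "INT:seed" if a seed-specific widget is registered
app.getWidgetType(["FLOAT", { default: 8.0 }], "cfg");  // "FLOAT" – falls back to the bare type widget
app.getWidgetType(["LATENT", {}], "samples");           // null – becomes a node connection input instead
```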
loadTemplateData(templateData) { loadTemplateData(templateData) {
@@ -1475,9 +1520,14 @@ export class ComfyApp {
showMissingNodesError(missingNodeTypes, hasAddedNodes = true) { showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
this.ui.dialog.show( this.ui.dialog.show(
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map( $el("div", [
(t) => `<li>${t}</li>` $el("span", { textContent: "When loading the graph, the following node types were not found: " }),
).join("")}</ul>${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}` $el(
"ul",
Array.from(new Set(missingNodeTypes)).map((t) => $el("li", { textContent: t }))
),
...(hasAddedNodes ? [$el("span", { textContent: "Nodes that have failed to load will show as red on the graph." })] : []),
])
); );
this.logging.addEntry("Comfy.App", "warn", { this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes, MissingNodes: missingNodeTypes,
@@ -1488,31 +1538,35 @@ export class ComfyApp {
* Populates the graph with the specified workflow data * Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object * @param {*} graphData A serialized graph object
*/ */
loadGraphData(graphData) { async loadGraphData(graphData) {
this.clean(); this.clean();
let reset_invalid_values = false; let reset_invalid_values = false;
if (!graphData) { if (!graphData) {
if (typeof structuredClone === "undefined") graphData = defaultGraph;
{
graphData = JSON.parse(JSON.stringify(defaultGraph));
}else
{
graphData = structuredClone(defaultGraph);
}
reset_invalid_values = true; reset_invalid_values = true;
} }
if (typeof structuredClone === "undefined")
{
graphData = JSON.parse(JSON.stringify(graphData));
}else
{
graphData = structuredClone(graphData);
}
const missingNodeTypes = []; const missingNodeTypes = [];
await this.#invokeExtensionsAsync("beforeConfigureGraph", graphData, missingNodeTypes);
for (let n of graphData.nodes) { for (let n of graphData.nodes) {
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix
if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix
// Find missing node types // Find missing node types
if (!(n.type in LiteGraph.registered_node_types)) { if (!(n.type in LiteGraph.registered_node_types)) {
n.type = sanitizeNodeName(n.type);
missingNodeTypes.push(n.type); missingNodeTypes.push(n.type);
n.type = sanitizeNodeName(n.type);
} }
} }
@@ -1604,6 +1658,7 @@ export class ComfyApp {
if (missingNodeTypes.length) { if (missingNodeTypes.length) {
this.showMissingNodesError(missingNodeTypes); this.showMissingNodesError(missingNodeTypes);
} }
await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes);
} }
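loadGraphData is now async and wraps graph configuration in two new extension hooks, beforeConfigureGraph and afterConfigureGraph. A minimal extension sketch, assuming the usual app.registerExtension API; the extension name and import path are illustrative:

```js
import { app } from "../../scripts/app.js"; // path as seen from a custom node's web extension

app.registerExtension({
  name: "example.graphConfigureHooks",
  async beforeConfigureGraph(graphData, missingNodeTypes) {
    // Runs before the (already deep-cloned) workflow is applied to the graph.
    console.debug("configuring", graphData.nodes.length, "nodes");
  },
  async afterConfigureGraph(missingNodeTypes) {
    // Runs after configuration and after missing-node errors have been reported.
    if (missingNodeTypes.length) console.warn("missing node types:", missingNodeTypes);
  },
});
```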
/** /**
@@ -1611,22 +1666,24 @@ export class ComfyApp {
* @returns The workflow and node links * @returns The workflow and node links
*/ */
async graphToPrompt() { async graphToPrompt() {
for (const node of this.graph.computeExecutionOrder(false)) { for (const outerNode of this.graph.computeExecutionOrder(false)) {
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) {
if (node.isVirtualNode) { if (node.isVirtualNode) {
// Don't serialize frontend only nodes but let them make changes // Don't serialize frontend only nodes but let them make changes
if (node.applyToGraph) { if (node.applyToGraph) {
node.applyToGraph(); node.applyToGraph();
} }
continue; }
} }
} }
const workflow = this.graph.serialize(); const workflow = this.graph.serialize();
const output = {}; const output = {};
// Process nodes in order of execution // Process nodes in order of execution
for (const node of this.graph.computeExecutionOrder(false)) { for (const outerNode of this.graph.computeExecutionOrder(false)) {
const n = workflow.nodes.find((n) => n.id === node.id); const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) {
if (node.isVirtualNode) { if (node.isVirtualNode) {
continue; continue;
} }
@@ -1644,7 +1701,7 @@ export class ComfyApp {
for (const i in widgets) { for (const i in widgets) {
const widget = widgets[i]; const widget = widgets[i];
if (!widget.options || widget.options.serialize !== false) { if (!widget.options || widget.options.serialize !== false) {
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(node, i) : widget.value;
} }
} }
} }
@@ -1688,6 +1745,9 @@ export class ComfyApp {
} }
if (link) { if (link) {
if (parent?.updateLink) {
link = parent.updateLink(link);
}
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
} }
} }
@@ -1698,6 +1758,7 @@ export class ComfyApp {
class_type: node.comfyClass, class_type: node.comfyClass,
}; };
} }
}
// Remove inputs connected to removed nodes // Remove inputs connected to removed nodes
@@ -1816,7 +1877,7 @@ export class ComfyApp {
const pngInfo = await getPngMetadata(file); const pngInfo = await getPngMetadata(file);
if (pngInfo) { if (pngInfo) {
if (pngInfo.workflow) { if (pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow)); await this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.parameters) { } else if (pngInfo.parameters) {
importA1111(this.graph, pngInfo.parameters); importA1111(this.graph, pngInfo.parameters);
} }
@@ -1832,21 +1893,21 @@ export class ComfyApp {
} }
} else if (file.type === "application/json" || file.name?.endsWith(".json")) { } else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = () => { reader.onload = async () => {
const jsonContent = JSON.parse(reader.result); const jsonContent = JSON.parse(reader.result);
if (jsonContent?.templates) { if (jsonContent?.templates) {
this.loadTemplateData(jsonContent); this.loadTemplateData(jsonContent);
} else if(this.isApiJson(jsonContent)) { } else if(this.isApiJson(jsonContent)) {
this.loadApiJson(jsonContent); this.loadApiJson(jsonContent);
} else { } else {
this.loadGraphData(jsonContent); await this.loadGraphData(jsonContent);
} }
}; };
reader.readAsText(file); reader.readAsText(file);
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) { } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file); const info = await getLatentMetadata(file);
if (info.workflow) { if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow)); await this.loadGraphData(JSON.parse(info.workflow));
} }
} }
} }
@@ -1867,7 +1928,7 @@ export class ComfyApp {
for (const id of ids) { for (const id of ids) {
const data = apiData[id]; const data = apiData[id];
const node = LiteGraph.createNode(data.class_type); const node = LiteGraph.createNode(data.class_type);
node.id = id; node.id = isNaN(+id) ? id : +id;
graph.add(node); graph.add(node);
} }

322
web/scripts/domWidget.js Normal file

@@ -0,0 +1,322 @@
import { app, ANIM_PREVIEW_WIDGET } from "./app.js";
const SIZE = Symbol();
function intersect(a, b) {
const x = Math.max(a.x, b.x);
const num1 = Math.min(a.x + a.width, b.x + b.width);
const y = Math.max(a.y, b.y);
const num2 = Math.min(a.y + a.height, b.y + b.height);
if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y];
else return null;
}
function getClipPath(node, element, elRect) {
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
if (selectedNode && selectedNode !== node) {
const MARGIN = 7;
const scale = app.canvas.ds.scale;
const bounding = selectedNode.getBounding();
const intersection = intersect(
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
{
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
width: bounding[2] + MARGIN + MARGIN,
height: bounding[3] + MARGIN + MARGIN,
}
);
if (!intersection) {
return "";
}
const widgetRect = element.getBoundingClientRect();
const clipX = intersection[0] - widgetRect.x / scale + "px";
const clipY = intersection[1] - widgetRect.y / scale + "px";
const clipWidth = intersection[2] + "px";
const clipHeight = intersection[3] + "px";
const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`;
return path;
}
return "";
}
function computeSize(size) {
if (this.widgets?.[0]?.last_y == null) return;
let y = this.widgets[0].last_y;
let freeSpace = size[1] - y;
let widgetHeight = 0;
let dom = [];
for (const w of this.widgets) {
if (w.type === "converted-widget") {
// Ignore
delete w.computedHeight;
} else if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else if (w.element) {
// Extract DOM widget size info
const styles = getComputedStyle(w.element);
let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height"));
let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height"));
let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height");
if (prefHeight.endsWith?.("%")) {
prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100);
} else {
prefHeight = parseInt(prefHeight);
if (isNaN(minHeight)) {
minHeight = prefHeight;
}
}
if (isNaN(minHeight)) {
minHeight = 50;
}
if (!isNaN(maxHeight)) {
if (!isNaN(prefHeight)) {
prefHeight = Math.min(prefHeight, maxHeight);
} else {
prefHeight = maxHeight;
}
}
dom.push({
minHeight,
prefHeight,
w,
});
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
freeSpace -= widgetHeight;
// Calculate sizes with all widgets at their min height
const prefGrow = []; // Nodes that want to grow to their preferred size
const canGrow = []; // Nodes that can grow to auto size
let growBy = 0;
for (const d of dom) {
freeSpace -= d.minHeight;
if (isNaN(d.prefHeight)) {
canGrow.push(d);
d.w.computedHeight = d.minHeight;
} else {
const diff = d.prefHeight - d.minHeight;
if (diff > 0) {
prefGrow.push(d);
growBy += diff;
d.diff = diff;
} else {
d.w.computedHeight = d.minHeight;
}
}
}
if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) {
// Allocate space for image
freeSpace -= 220;
}
if (freeSpace < 0) {
// Not enough space for all widgets so we need to grow
size[1] -= freeSpace;
this.graph.setDirtyCanvas(true);
} else {
// Share the space between each
const growDiff = freeSpace - growBy;
if (growDiff > 0) {
// All pref sizes can be fulfilled
freeSpace = growDiff;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight;
}
} else {
// We need to grow evenly
const shared = -growDiff / prefGrow.length;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight - shared;
}
freeSpace = 0;
}
if (freeSpace > 0 && canGrow.length) {
// Grow any that are auto height
const shared = freeSpace / canGrow.length;
for (const d of canGrow) {
d.w.computedHeight += shared;
}
}
}
// Position each of the widgets
for (const w of this.widgets) {
w.y = y;
if (w.computedHeight) {
y += w.computedHeight;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen
const elementWidgets = new Set();
const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
LGraphCanvas.prototype.computeVisibleNodes = function () {
const visibleNodes = computeVisibleNodes.apply(this, arguments);
for (const node of app.graph._nodes) {
if (elementWidgets.has(node)) {
const hidden = visibleNodes.indexOf(node) === -1;
for (const w of node.widgets) {
if (w.element) {
w.element.hidden = hidden;
if (hidden) {
w.options.onHide?.(w);
}
}
}
}
}
return visibleNodes;
};
let enableDomClipping = true;
export function addDomClippingSetting() {
app.ui.settings.addSetting({
id: "Comfy.DOMClippingEnabled",
name: "Enable DOM element clipping (enabling may reduce performance)",
type: "boolean",
defaultValue: enableDomClipping,
onChange(value) {
enableDomClipping = !!value;
},
});
}
LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options };
if (!element.parentElement) {
document.body.append(element);
}
let mouseDownHandler;
if (element.blur) {
mouseDownHandler = (event) => {
if (!element.contains(event.target)) {
element.blur();
}
};
document.addEventListener("mousedown", mouseDownHandler);
}
const widget = {
type,
name,
get value() {
return options.getValue?.() ?? undefined;
},
set value(v) {
options.setValue?.(v);
widget.callback?.(widget.value);
},
draw: function (ctx, node, widgetWidth, y, widgetHeight) {
if (widget.computedHeight == null) {
computeSize.call(node, node.size);
}
const hidden =
node.flags?.collapsed ||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
widget.computedHeight <= 0 ||
widget.type === "converted-widget";
element.hidden = hidden;
element.style.display = hidden ? "none" : null;
if (hidden) {
widget.options.onHide?.(widget);
return;
}
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d);
Object.assign(element.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - margin * 2}px`,
height: `${(widget.computedHeight ?? 50) - margin * 2}px`,
position: "absolute",
zIndex: app.graph._nodes.indexOf(node),
});
if (enableDomClipping) {
element.style.clipPath = getClipPath(node, element, elRect);
element.style.willChange = "clip-path";
}
this.options.onDraw?.(widget);
},
element,
options,
onRemove() {
if (mouseDownHandler) {
document.removeEventListener("mousedown", mouseDownHandler);
}
element.remove();
},
};
for (const evt of options.selectOn) {
element.addEventListener(evt, () => {
app.canvas.selectNode(this);
app.canvas.bringToFront(this);
});
}
this.addCustomWidget(widget);
elementWidgets.add(this);
const collapse = this.collapse;
this.collapse = function() {
collapse.apply(this, arguments);
if(this.flags?.collapsed) {
element.hidden = true;
element.style.display = "none";
}
}
const onRemoved = this.onRemoved;
this.onRemoved = function () {
element.remove();
elementWidgets.delete(this);
onRemoved?.apply(this, arguments);
};
if (!this[SIZE]) {
this[SIZE] = true;
const onResize = this.onResize;
this.onResize = function (size) {
options.beforeResize?.call(widget, this);
computeSize.call(this, size);
onResize?.apply(this, arguments);
options.afterResize?.call(widget, this);
};
}
return widget;
};
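addDOMWidget is the new generic way to attach an HTML element to a node (the multiline text widget further down is rebuilt on top of it). A minimal usage sketch; the node variable, element and option values below are illustrative rather than taken from the diff:

```js
// Attach a <select> element to a node as a DOM widget.
const selectEl = document.createElement("select");
selectEl.append(new Option("low"), new Option("high"));

const widget = node.addDOMWidget("quality", "customselect", selectEl, {
  getValue: () => selectEl.value,           // read through widget.value
  setValue: (v) => { selectEl.value = v; }, // write through widget.value
  getMinHeight: () => 30,                   // sizing hint consumed by computeSize
  hideOnZoom: true,                         // hide the element when zoomed out below 50%
});

widget.serializeValue = () => selectEl.value; // what graphToPrompt stores, if anything
```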


@@ -24,7 +24,7 @@ export function getPngMetadata(file) {
const length = dataView.getUint32(offset); const length = dataView.getUint32(offset);
// Get the chunk type // Get the chunk type
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8)); const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
if (type === "tEXt") { if (type === "tEXt" || type == "comf") {
// Get the keyword // Get the keyword
let keyword_end = offset + 8; let keyword_end = offset + 8;
while (pngData[keyword_end] !== 0) { while (pngData[keyword_end] !== 0) {
@@ -50,7 +50,6 @@ export function getPngMetadata(file) {
function parseExifData(exifData) { function parseExifData(exifData) {
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian) // Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949; const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
console.log(exifData);
// Function to read 16-bit and 32-bit integers from binary data // Function to read 16-bit and 32-bit integers from binary data
function readInt(offset, isLittleEndian, length) { function readInt(offset, isLittleEndian, length) {
@@ -126,6 +125,9 @@ export function getWebpMetadata(file) {
const chunk_length = dataView.getUint32(offset + 4, true); const chunk_length = dataView.getUint32(offset + 4, true);
const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4)); const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
if (chunk_type === "EXIF") { if (chunk_type === "EXIF") {
if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") {
offset += 6;
}
let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length)); let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
for (var key in data) { for (var key in data) {
var value = data[key]; var value = data[key];


@@ -462,8 +462,8 @@ class ComfyList {
return $el("div", {textContent: item.prompt[0] + ": "}, [ return $el("div", {textContent: item.prompt[0] + ": "}, [
$el("button", { $el("button", {
textContent: "Load", textContent: "Load",
onclick: () => { onclick: async () => {
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
if (item.outputs) { if (item.outputs) {
app.nodeOutputs = item.outputs; app.nodeOutputs = item.outputs;
} }
@@ -599,7 +599,7 @@ export class ComfyUI {
const fileInput = $el("input", { const fileInput = $el("input", {
id: "comfy-file-input", id: "comfy-file-input",
type: "file", type: "file",
accept: ".json,image/png,.latent,.safetensors", accept: ".json,image/png,.latent,.safetensors,image/webp",
style: {display: "none"}, style: {display: "none"},
parent: document.body, parent: document.body,
onchange: () => { onchange: () => {
@@ -784,9 +784,9 @@ export class ComfyUI {
} }
}), }),
$el("button", { $el("button", {
id: "comfy-load-default-button", textContent: "Load Default", onclick: () => { id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
if (!confirmClear.value || confirm("Load default workflow?")) { if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData() await app.loadGraphData()
} }
} }
}), }),


@@ -0,0 +1,97 @@
import { $el } from "../ui.js";
export function calculateImageGrid(imgs, dw, dh) {
let best = 0;
let w = imgs[0].naturalWidth;
let h = imgs[0].naturalHeight;
const numImages = imgs.length;
let cellWidth, cellHeight, cols, rows, shiftX;
// compact style
for (let c = 1; c <= numImages; c++) {
const r = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / r;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
rows = r;
shiftX = c * ((cW - imageW) / 2);
}
}
return { cellWidth, cellHeight, cols, rows, shiftX };
}
export function createImageHost(node) {
const el = $el("div.comfy-img-preview");
let currentImgs;
let first = true;
function updateSize() {
let w = null;
let h = null;
if (currentImgs) {
let elH = el.clientHeight;
if (first) {
first = false;
// On first run, if we are small then grow a bit
if (elH < 190) {
elH = 190;
}
el.style.setProperty("--comfy-widget-min-height", elH);
} else {
el.style.setProperty("--comfy-widget-min-height", null);
}
const nw = node.size[0];
({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH));
w += "px";
h += "px";
el.style.setProperty("--comfy-img-preview-width", w);
el.style.setProperty("--comfy-img-preview-height", h);
}
}
return {
el,
updateImages(imgs) {
if (imgs !== currentImgs) {
if (currentImgs == null) {
requestAnimationFrame(() => {
updateSize();
});
}
el.replaceChildren(...imgs);
currentImgs = imgs;
node.onResize(node.size);
node.graph.setDirtyCanvas(true, true);
}
},
getHeight() {
updateSize();
},
onDraw() {
// elementFromPoint uses a hit test to find elements, so we need to toggle pointer events
el.style.pointerEvents = "all";
const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]);
el.style.pointerEvents = "none";
if(!over) return;
// Set the overIndex so Open Image etc work
const idx = currentImgs.indexOf(over);
node.overIndex = idx;
},
};
}
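calculateImageGrid packs the preview images into the densest grid that fits the draw area; it is the same logic that previously lived inline in the node-drawing code. A quick worked call, using plain objects in place of real images (only naturalWidth/naturalHeight are read; the import path is illustrative):

```js
import { calculateImageGrid } from "./ui/imagePreview.js"; // illustrative path

const imgs = Array.from({ length: 4 }, () => ({ naturalWidth: 640, naturalHeight: 360 }));
const grid = calculateImageGrid(imgs, 400, 300);
// => { cellWidth: 200, cellHeight: 112.5, cols: 2, rows: 2, shiftX: 0 }
// i.e. a 2x2 grid whose cells preserve the 16:9 aspect ratio of the source images.
```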


@@ -1,4 +1,5 @@
import { api } from "./api.js" import { api } from "./api.js"
import "./domWidget.js";
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
let defaultVal = inputData[1]["default"]; let defaultVal = inputData[1]["default"];
@@ -22,18 +23,89 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
} }
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) {
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { let name = inputData[1]?.control_after_generate;
if(typeof name !== "string") {
name = widgetName;
}
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, {
addFilterList: false,
controlAfterGenerateName: name
}, inputData);
return widgets[0];
}
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", options, inputData) {
if (!defaultValue) defaultValue = "randomize";
if (!options) options = {};
const getName = (defaultName, optionName) => {
let name = defaultName;
if (options[optionName]) {
name = options[optionName];
} else if (typeof inputData?.[1]?.[defaultName] === "string") {
name = inputData?.[1]?.[defaultName];
} else if (inputData?.[1]?.control_prefix) {
name = inputData?.[1]?.control_prefix + " " + name
}
return name;
}
const widgets = [];
const valueControl = node.addWidget(
"combo",
getName("control_after_generate", "controlAfterGenerateName"),
defaultValue,
function () {},
{
values: ["fixed", "increment", "decrement", "randomize"], values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt. serialize: false, // Don't include this in prompt.
}); }
valueControl.afterQueued = () => { );
widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
let comboFilter;
if (isCombo && options.addFilterList !== false) {
comboFilter = node.addWidget(
"string",
getName("control_filter_list", "controlFilterListName"),
"",
function () {},
{
serialize: false, // Don't include this in prompt.
}
);
widgets.push(comboFilter);
}
valueControl.afterQueued = () => {
var v = valueControl.value; var v = valueControl.value;
if (targetWidget.type == "combo" && v !== "fixed") { if (isCombo && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value); let values = targetWidget.options.values;
let current_length = targetWidget.options.values.length; const filter = comboFilter?.value;
if (filter) {
let check;
if (filter.startsWith("/") && filter.endsWith("/")) {
try {
const regex = new RegExp(filter.substring(1, filter.length - 1));
check = (item) => regex.test(item);
} catch (error) {
console.error("Error constructing RegExp filter for node " + node.id, filter, error);
}
}
if (!check) {
const lower = filter.toLocaleLowerCase();
check = (item) => item.toLocaleLowerCase().includes(lower);
}
values = values.filter(item => check(item));
if (!values.length && targetWidget.options.values.length) {
console.warn("Filter for node " + node.id + " has filtered out all items", filter);
}
}
let current_index = values.indexOf(targetWidget.value);
let current_length = values.length;
switch (v) { switch (v) {
case "increment": case "increment":
@@ -50,11 +122,12 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
current_index = Math.max(0, current_index); current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index); current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) { if (current_index >= 0) {
let value = targetWidget.options.values[current_index]; let value = values[current_index];
targetWidget.value = value; targetWidget.value = value;
targetWidget.callback(value); targetWidget.callback(value);
} }
} else { //number } else {
//number
let min = targetWidget.options.min; let min = targetWidget.options.min;
let max = targetWidget.options.max; let max = targetWidget.options.max;
// limit to something that javascript can handle // limit to something that javascript can handle
@@ -79,184 +152,66 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
} }
/*check if values are over or under their respective /*check if values are over or under their respective
* ranges and set them to min or max.*/ * ranges and set them to min or max.*/
if (targetWidget.value < min) if (targetWidget.value < min) targetWidget.value = min;
targetWidget.value = min;
if (targetWidget.value > max) if (targetWidget.value > max)
targetWidget.value = max; targetWidget.value = max;
targetWidget.callback(targetWidget.value); targetWidget.callback(targetWidget.value);
} }
} };
return valueControl; return widgets;
}; };
function seedWidget(node, inputName, inputData, app) { function seedWidget(node, inputName, inputData, app, widgetName) {
const seed = ComfyWidgets.INT(node, inputName, inputData, app); const seed = createIntWidget(node, inputName, inputData, app, true);
const seedControl = addValueControlWidget(node, seed.widget, "randomize"); const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName, inputData);
seed.widget.linkedWidgets = [seedControl]; seed.widget.linkedWidgets = [seedControl];
return seed; return seed;
} }
const MultilineSymbol = Symbol(); function createIntWidget(node, inputName, inputData, app, isSeedInput) {
const MultilineResizeSymbol = Symbol(); const control = inputData[1]?.control_after_generate;
if (!isSeedInput && control) {
return seedWidget(node, inputName, inputData, app, typeof control === "string" ? control : undefined);
}
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
widgetType,
inputName,
val,
function (v) {
const s = this.options.step / 10;
this.value = Math.round(v / s) * s;
},
config
),
};
}
function addMultilineWidget(node, name, opts, app) { function addMultilineWidget(node, name, opts, app) {
const MIN_SIZE = 50; const inputEl = document.createElement("textarea");
inputEl.className = "comfy-multiline-input";
inputEl.value = opts.defaultVal;
inputEl.placeholder = opts.placeholder || name;
function computeSize(size) { const widget = node.addDOMWidget(name, "customtext", inputEl, {
if (node.widgets[0].last_y == null) return; getValue() {
return inputEl.value;
let y = node.widgets[0].last_y;
let freeSpace = size[1] - y;
// Compute the height of all non customtext widgets
let widgetHeight = 0;
const multi = [];
for (let i = 0; i < node.widgets.length; i++) {
const w = node.widgets[i];
if (w.type === "customtext") {
multi.push(w);
} else {
if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// See how large each text input can be
freeSpace -= widgetHeight;
freeSpace /= multi.length + (!!node.imgs?.length);
if (freeSpace < MIN_SIZE) {
// There isnt enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
node.graph.setDirtyCanvas(true);
}
// Position each of the widgets
for (const w of node.widgets) {
w.y = y;
if (w.type === "customtext") {
y += freeSpace;
w.computedHeight = freeSpace - multi.length*4;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
node.inputHeight = freeSpace;
}
const widget = {
type: "customtext",
name,
get value() {
return this.inputEl.value;
}, },
set value(x) { setValue(v) {
this.inputEl.value = x; inputEl.value = v;
}, },
draw: function (ctx, _, widgetWidth, y, widgetHeight) {
if (!this.parent.inputHeight) {
// If we are initially offscreen when created we wont have received a resize event
// Calculate it here instead
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d)
Object.assign(this.inputEl.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - (margin * 2)}px`,
height: `${this.parent.inputHeight - (margin * 2)}px`,
position: "absolute",
background: (!node.color)?'':node.color,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
}); });
this.inputEl.hidden = !visible; widget.inputEl = inputEl;
},
}; inputEl.addEventListener("input", () => {
widget.inputEl = document.createElement("textarea"); widget.callback?.(widget.value);
widget.inputEl.className = "comfy-multiline-input";
widget.inputEl.value = opts.defaultVal;
widget.inputEl.placeholder = opts.placeholder || "";
document.addEventListener("mousedown", function (event) {
if (!widget.inputEl.contains(event.target)) {
widget.inputEl.blur();
}
}); });
widget.parent = node;
document.body.appendChild(widget.inputEl);
node.addCustomWidget(widget);
app.canvas.onDrawBackground = function () {
// Draw node isnt fired once the node is off the screen
// if it goes off screen quickly, the input may not be removed
// this shifts it off screen so it can be moved back if the node is visible.
for (let n in app.graph._nodes) {
n = graph._nodes[n];
for (let w in n.widgets) {
let wid = n.widgets[w];
if (Object.hasOwn(wid, "inputEl")) {
wid.inputEl.style.left = -8000 + "px";
wid.inputEl.style.position = "absolute";
}
}
}
};
node.onRemoved = function () {
// When removing this node we need to remove the input from the DOM
for (let y in this.widgets) {
if (this.widgets[y].inputEl) {
this.widgets[y].inputEl.remove();
}
}
};
widget.onRemove = () => {
widget.inputEl?.remove();
// Restore original size handler if we are the last
if (!--node[MultilineSymbol]) {
node.onResize = node[MultilineResizeSymbol];
delete node[MultilineSymbol];
delete node[MultilineResizeSymbol];
}
};
if (node[MultilineSymbol]) {
node[MultilineSymbol]++;
} else {
node[MultilineSymbol] = 1;
const onResize = (node[MultilineResizeSymbol] = node.onResize);
node.onResize = function (size) {
computeSize(size);
// Call original resizer handler
if (onResize) {
onResize.apply(this, arguments);
}
};
}
return { minWidth: 400, minHeight: 200, widget }; return { minWidth: 400, minHeight: 200, widget };
} }
@@ -288,31 +243,26 @@ export const ComfyWidgets = {
}, config) }; }, config) };
}, },
INT(node, inputName, inputData, app) { INT(node, inputName, inputData, app) {
let widgetType = isSlider(inputData[1]["display"], app); return createIntWidget(node, inputName, inputData, app);
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
widgetType,
inputName,
val,
function (v) {
const s = this.options.step / 10;
this.value = Math.round(v / s) * s;
},
config
),
};
}, },
BOOLEAN(node, inputName, inputData) { BOOLEAN(node, inputName, inputData) {
let defaultVal = inputData[1]["default"]; let defaultVal = false;
let options = {};
if (inputData[1]) {
if (inputData[1].default)
defaultVal = inputData[1].default;
if (inputData[1].label_on)
options["on"] = inputData[1].label_on;
if (inputData[1].label_off)
options["off"] = inputData[1].label_off;
}
return { return {
widget: node.addWidget( widget: node.addWidget(
"toggle", "toggle",
inputName, inputName,
defaultVal, defaultVal,
() => {}, () => {},
{"on": inputData[1].label_on, "off": inputData[1].label_off} options,
) )
}; };
}, },
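The BOOLEAN widget now defaults to false when no options are supplied and honours optional label_on/label_off display strings. An illustrative inputData shape, assumed from how the widget reads it above (node here is hypothetical):

```js
const inputData = ["BOOLEAN", { default: true, label_on: "enabled", label_off: "disabled" }];
const { widget } = ComfyWidgets.BOOLEAN(node, "add_noise", inputData);
// widget.value === true, and the toggle renders "enabled"/"disabled" instead of true/false
```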
@@ -338,10 +288,14 @@ export const ComfyWidgets = {
if (inputData[1] && inputData[1].default) { if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default; defaultValue = inputData[1].default;
} }
return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
if (inputData[1]?.control_after_generate) {
res.widget.linkedWidgets = addValueControlWidgets(node, res.widget, undefined, undefined, inputData);
}
return res;
}, },
IMAGEUPLOAD(node, inputName, inputData, app) { IMAGEUPLOAD(node, inputName, inputData, app) {
const imageWidget = node.widgets.find((w) => w.name === "image"); const imageWidget = node.widgets.find((w) => w.name === (inputData[1]?.widget ?? "image"));
let uploadWidget; let uploadWidget;
function showImage(name) { function showImage(name) {
@@ -455,9 +409,10 @@ export const ComfyWidgets = {
document.body.append(fileInput); document.body.append(fileInput);
// Create the button widget for selecting the files // Create the button widget for selecting the files
uploadWidget = node.addWidget("button", "choose file to upload", "image", () => { uploadWidget = node.addWidget("button", inputName, "image", () => {
fileInput.click(); fileInput.click();
}); });
uploadWidget.label = "choose file to upload";
uploadWidget.serialize = false; uploadWidget.serialize = false;
// Add handler to check if an image is being dragged over our node // Add handler to check if an image is being dragged over our node


@@ -409,6 +409,21 @@ dialog::backdrop {
width: calc(100% - 10px); width: calc(100% - 10px);
} }
.comfy-img-preview {
pointer-events: none;
overflow: hidden;
display: flex;
flex-wrap: wrap;
align-content: flex-start;
justify-content: center;
}
.comfy-img-preview img {
object-fit: contain;
width: var(--comfy-img-preview-width);
height: var(--comfy-img-preview-height);
}
/* Search box */ /* Search box */
.litegraph.litesearchbox { .litegraph.litesearchbox {