Merge branch 'master' into image-cache

This commit is contained in:
Jairo Correa 2023-12-02 05:16:21 -03:00 committed by GitHub
commit c92f3dca73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
57 changed files with 5588 additions and 924 deletions

9
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,9 @@
{
"path-intellisense.mappings": {
"../": "${workspaceFolder}/web/extensions/core"
},
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}

View File

@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x and SDXL
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@ -30,6 +30,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
@ -43,6 +45,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
@ -98,6 +101,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
This is the command to install the nightly with ROCm 5.7 that might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
### NVIDIA
@ -190,7 +194,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers this option does not do anything.
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
```--dont-upcast-attention```

View File

@ -54,6 +54,7 @@ class ControlNet(nn.Module):
transformer_depth_output=None,
device=None,
operations=comfy.ops,
**kwargs,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"

View File

@ -62,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")

View File

@ -33,7 +33,7 @@ class ControlBase:
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (1.0, 0.0)
self.timestep_percent_range = (0.0, 1.0)
self.timestep_range = None
if device is None:
@ -42,7 +42,7 @@ class ControlBase:
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range

View File

@ -858,7 +858,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
return (input - model(input, sigma_in, **kwargs)) / sigma
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
timesteps = sigmas.clone()
if sigmas[-1] == 0:
timesteps = sigmas[:]

View File

@ -750,3 +750,61 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_end = sigmas[-1]
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == s_end:
# Euler method
x = x + d * dt
elif sigmas[i + 2] == s_end:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0]
w2 = sigmas[i+1]/w
w1 = 1 - w2
d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt
else:
# Heun++
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
dt_2 = sigmas[i + 2] - sigmas[i + 1]
x_3 = x_2 + d_2 * dt_2
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
w = 3 * sigmas[0]
w2 = sigmas[i + 1] / w
w3 = sigmas[i + 2] / w
w1 = 1 - w2 - w3
d_prime = w1 * d + w2 * d_2 + w3 * d_3
x = x + d_prime * dt
return x

View File

@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial
from .diffusionmodules.util import checkpoint
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
@ -276,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
)
return r1
BROKEN_XFORMERS = False
try:
x_vers = xformers.__version__
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
except:
pass
def attention_xformers(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
if BROKEN_XFORMERS:
if b * heads > 65535:
return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
lambda t: t.unsqueeze(3)
@ -370,53 +383,72 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
context_dim_attn2 = None
if not switch_temporal_ca_to_sa:
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
extra_options = {}
block = None
block_index = 0
if "current_index" in transformer_options:
extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
block_index = transformer_options["block_index"]
extra_options["block_index"] = block_index
if "original_shape" in transformer_options:
extra_options["original_shape"] = transformer_options["original_shape"]
if "block" in transformer_options:
block = transformer_options["block"]
extra_options["block"] = block
if "cond_or_uncond" in transformer_options:
extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}
block = transformer_options.get("block", None)
block_index = transformer_options.get("block_index", 0)
transformer_patches = {}
transformer_patches_replace = {}
for k in transformer_options:
if k == "patches":
transformer_patches = transformer_options[k]
elif k == "patches_replace":
transformer_patches_replace = transformer_options[k]
else:
extra_options[k] = transformer_options[k]
extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head
if "patches_replace" in transformer_options:
transformer_patches_replace = transformer_options["patches_replace"]
else:
transformer_patches_replace = {}
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
n = self.norm1(x)
if self.disable_self_attn:
@ -465,31 +497,34 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
x = p(x, extra_options)
n = self.norm2(x)
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block
if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
if self.attn2 is not None:
n = self.norm2(x)
if self.switch_temporal_ca_to_sa:
context_attn2 = n
else:
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block
if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
n = p(n, extra_options)
x += n
x = self.ff(self.norm3(x)) + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
return x
@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
# timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(
x,
context=spatial_context,
transformer_options=transformer_options,
)
x_mix = x
x_mix = x_mix + emb
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
from .util import (
checkpoint,
@ -12,8 +14,9 @@ from .util import (
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
@ -28,6 +31,26 @@ class TimestepBlock(nn.Module):
Apply the module to `x` given `emb` timestep embeddings.
"""
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialVideoTransformer):
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
@ -35,31 +58,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input.
"""
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
def forward(self, *args, **kwargs):
return forward_timestep_embed(self, *args, **kwargs)
class Upsample(nn.Module):
"""
@ -154,6 +154,9 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None,
device=None,
operations=comfy.ops
@ -166,11 +169,17 @@ class ResBlock(TimestepBlock):
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims
if isinstance(kernel_size, list):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
self.updown = up or down
@ -184,19 +193,24 @@ class ResBlock(TimestepBlock):
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.skip_t_emb = skip_t_emb
if self.skip_t_emb:
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
)
@ -204,7 +218,7 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = operations.conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
)
else:
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
@ -230,19 +244,110 @@ class ResBlock(TimestepBlock):
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
emb_out = None
if not self.skip_t_emb:
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_norm(h)
if emb_out is not None:
scale, shift = th.chunk(emb_out, 2, dim=1)
h *= (1 + scale)
h += shift
h = out_rest(h)
else:
h = h + emb_out
if emb_out is not None:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
dtype=dtype,
device=device,
operations=operations
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
dtype=dtype,
device=device,
operations=operations
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
@ -255,7 +360,10 @@ def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
h += ctrl
try:
h += ctrl
except:
print("warning control could not be applied", h.shape, ctrl.shape)
return h
class UNetModel(nn.Module):
@ -316,6 +424,16 @@ class UNetModel(nn.Module):
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=comfy.ops,
):
@ -370,8 +488,12 @@ class UNetModel(nn.Module):
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_temporal_resblocks = use_temporal_resblock
self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
@ -408,13 +530,104 @@ class UNetModel(nn.Module):
input_block_chans = [model_channels]
ch = model_channels
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
):
if use_temporal_attention:
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=comfy.ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
else:
return ResBlock(
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
use_checkpoint=use_checkpoint,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
@ -441,11 +654,9 @@ class UNetModel(nn.Module):
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(SpatialTransformer(
layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
@ -454,10 +665,13 @@ class UNetModel(nn.Module):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@ -487,10 +701,14 @@ class UNetModel(nn.Module):
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
mid_block = [
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
@ -499,15 +717,18 @@ class UNetModel(nn.Module):
operations=operations
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
@ -523,10 +744,13 @@ class UNetModel(nn.Module):
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
@ -554,19 +778,21 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@ -605,9 +831,13 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
@ -622,14 +852,24 @@ class UNetModel(nn.Module):
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
@ -646,7 +886,7 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)

View File

@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from einops import repeat, rearrange
from comfy.ldm.util import instantiate_from_config
import comfy.ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
else:
raise NotImplementedError()
return alpha
def forward(
self,
x_spatial,
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (

View File

@ -0,0 +1,244 @@
import functools
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
import comfy.ops
from .diffusionmodules.model import (
AttnBlock,
Decoder,
ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
b, c, h, w = x.shape
if timesteps is None:
timesteps = b
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps=None, skip_video=False):
if timesteps is None:
timesteps = input.shape[0]
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class AttnVideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = BasicTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
comfy.ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
comfy.ops.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps=None, skip_time_block=False):
if skip_time_block:
return super().forward(x)
if timesteps is None:
timesteps = x.shape[0]
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
if self.time_mode != "attn-only":
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
if self.time_mode not in ["conv-only", "only-last-conv"]:
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
if self.time_mode not in ["attn-only", "only-last-conv"]:
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)

View File

@ -10,17 +10,22 @@ from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
s = ModelSamplingDiscrete
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
class ModelSampling(s, c):
pass
@ -121,6 +126,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
@ -157,6 +163,17 @@ class BaseModel(torch.nn.Module):
def set_inpaint(self):
self.inpaint_model = True
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []
weights = []
@ -251,3 +268,48 @@ class SDXL(BaseModel):
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
fps_id = kwargs.get("fps", 6) - 1
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.Tensor([fps_id])))
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out

View File

@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
return last_transformer_depth, context_dim, use_linear_in_transformer
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
context_dim = None
use_linear_in_transformer = False
video_model = False
current_res = 1
count = 0
@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
if context_dim is None:
context_dim = out[1]
use_linear_in_transformer = out[2]
video_model = out[3]
else:
transformer_depth.append(0)
@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
unet_config["context_dim"] = context_dim
if video_model:
unet_config["extra_ff_mix_layer"] = True
unet_config["use_spatial_context"] = True
unet_config["merge_strategy"] = "learned_with_images"
unet_config["merge_factor"] = 0.0
unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True
else:
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False
return unet_config
def model_config_from_unet_config(unet_config):
@ -186,17 +202,24 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype):
match = {}
attention_resolutions = []
transformer_depth = []
attn_res = 1
for i in range(5):
k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
if k in state_dict:
match["context_dim"] = state_dict[k].shape[1]
attention_resolutions.append(attn_res)
attn_res *= 2
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
match["attention_resolutions"] = attention_resolutions
attn_res *= 2
if attn_blocks == 0:
transformer_depth.append(0)
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
@ -208,50 +231,65 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
for unet_config in supported_models:
matches = True

View File

@ -133,6 +133,10 @@ else:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
try:
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
except:
pass
try:
XFORMERS_VERSION = xformers.version.__version__
print("xformers version:", XFORMERS_VERSION)
@ -478,6 +482,21 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn
elif args.fp8_e5m2_text_enc:
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.fp32_text_enc:
return torch.float32
if should_use_fp16(device, prioritize_performance=False):
return torch.float16
else:
return torch.float32
def vae_device():
return get_torch_device()
@ -579,27 +598,6 @@ def get_free_memory(dev=None, torch_free_too=False):
else:
return mem_free_total
def batch_area_memory(area):
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: these formulas are copied from maximum_batch_area below
return (area / 20) * (1024 * 1024)
else:
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def maximum_batch_area():
global vram_state
if vram_state == VRAMState.NO_VRAM:
return 0
memory_free = get_free_memory() / (1024 * 1024)
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = 20 * memory_free
else:
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
area = ((memory_free - 1024) * 0.9) / (0.6)
return int(max(area, 0))
def cpu_mode():
global cpu_state
return cpu_state == CPUState.CPU

View File

@ -37,7 +37,7 @@ class ModelPatcher:
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@ -52,6 +52,9 @@ class ModelPatcher:
return True
return False
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@ -93,6 +96,12 @@ class ModelPatcher:
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
class EPS:
def calculate_input(self, sigma, noise):
@ -24,7 +24,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.beta_schedule
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.sigma_data = 1.0
@ -65,16 +65,65 @@ class ModelSamplingDiscrete(torch.nn.Module):
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep):
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
def set_sigma_range(self, sigma_min, sigma_max):
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)

View File

@ -83,7 +83,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, positive, negative, noise_mask, models

View File

@ -11,7 +11,7 @@ import comfy.conds
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
@ -134,7 +134,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
@ -170,9 +170,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area):
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
@ -218,12 +220,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
@ -242,11 +246,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return out_cond, out_uncond
max_total_area = model_management.maximum_batch_area()
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
@ -258,7 +261,7 @@ class CFGNoisePredictor(torch.nn.Module):
super().__init__()
self.inner_model = model
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
@ -511,52 +514,69 @@ class Sampler:
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
self.sampler_function = sampler_function
self.extra_options = extra_options
self.inpaint_options = inpaint_options
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
return samples
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
if inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if sampler_name == "dpm_fast":
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
total_steps = len(sigmas) - 1
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_adaptive_function
else:
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
if latent_image is not None:
noise += latent_image
if sampler_name == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif sampler_name == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
return samples
return KSAMPLER
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
@ -617,11 +637,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
print("error invalid scheduler", self.scheduler)
return sigmas
def sampler_class(name):
def sampler_object(name):
if name == "uni_pc":
sampler = UNIPC
sampler = UNIPC()
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
sampler = UNIPCBH2()
elif name == "ddim":
sampler = ksampler("euler", inpaint_options={"random": True})
else:
@ -686,6 +706,6 @@ class KSampler:
else:
return torch.zeros_like(noise)
sampler = sampler_class(self.sampler)
sampler = sampler_object(self.sampler)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)

View File

@ -23,6 +23,7 @@ import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@ -95,10 +96,7 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
if model_management.should_use_fp16(load_device, prioritize_performance=False):
params['dtype'] = torch.float16
else:
params['dtype'] = torch.float32
params['dtype'] = model_management.text_encoder_dtype(load_device)
self.cond_stage_model = clip(**(params))
@ -157,10 +155,24 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
decoder_config = encoder_config.copy()
decoder_config["video_kernel_size"] = [3, 1, 1]
decoder_config["alpha"] = 0.0
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
@ -175,10 +187,12 @@ class VAE:
if device is None:
device = model_management.vae_device()
self.device = device
self.offload_device = model_management.vae_offload_device()
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
self.first_stage_model.to(self.vae_dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -207,10 +221,9 @@ class VAE:
return samples
def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device)
try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
model_management.free_memory(memory_used, self.device)
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -223,22 +236,19 @@ class VAE:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
self.first_stage_model = self.first_stage_model.to(self.device)
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1)
def encode(self, pixel_samples):
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -251,14 +261,12 @@ class VAE:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
self.first_stage_model = self.first_stage_model.to(self.device)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples
def get_sd(self):
@ -444,6 +452,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
if output_clip:
@ -468,20 +477,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path)
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return None
new_sd = sd
else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
@ -501,6 +508,14 @@ def load_unet(unet_path): #load unet in diffusers format
print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd)
if model is None:
print("ERROR UNSUPPORTED UNET", unet_path)
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()])
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())

View File

@ -173,9 +173,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device), torch.float32):
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)

View File

@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
latent_format = latent_formats.SD15
@ -69,6 +71,10 @@ class SD20(supported_models_base.BASE):
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
return state_dict
@ -88,6 +94,7 @@ class SD21UnclipL(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 1536,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -100,6 +107,7 @@ class SD21UnclipH(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 2048,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -112,6 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
"context_dim": 1280,
"adm_in_channels": 2560,
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -148,7 +157,8 @@ class SDXL(supported_models_base.BASE):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -203,8 +213,34 @@ class SSD1B(SDXL):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 4, 4],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 768,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]

View File

@ -19,7 +19,7 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
beta_schedule = "linear"
sampling_settings = {}
latent_format = latent_formats.LatentFormat
@classmethod
@ -53,6 +53,12 @@ class BASE:
def process_clip_state_dict(self, state_dict):
return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_vae_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
def scale_latents(x):
@ -65,3 +66,11 @@ class TAESD(nn.Module):
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale

View File

@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
def copy_to_param(obj, attr, value):
@ -307,23 +307,25 @@ def bislerp(samples, width, height):
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new):
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
@ -336,7 +338,7 @@ def bislerp(samples, width, height):
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
#linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
@ -347,7 +349,7 @@ def bislerp(samples, width, height):
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
return result.to(orig_dtype)
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]

View File

@ -16,7 +16,7 @@ class BasicScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -36,7 +36,7 @@ class KarrasScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -54,7 +54,7 @@ class ExponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -73,7 +73,7 @@ class PolyexponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -81,6 +81,25 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
class VPScheduler:
@classmethod
def INPUT_TYPES(s):
@ -92,7 +111,7 @@ class VPScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -109,7 +128,7 @@ class SplitSigmas:
}
}
RETURN_TYPES = ("SIGMAS","SIGMAS")
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
@ -118,6 +137,24 @@ class SplitSigmas:
sigmas2 = sigmas[step:]
return (sigmas1, sigmas2)
class FlipSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
def get_sigmas(self, sigmas):
sigmas = sigmas.flip(0)
if sigmas[0] == 0:
sigmas[0] = 0.0001
return (sigmas,)
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@ -126,12 +163,12 @@ class KSamplerSelect:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
sampler = comfy.samplers.sampler_class(sampler_name)()
sampler = comfy.samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_2M_SDE:
@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
return (sampler, )
@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
sampler_name = "dpmpp_sde"
else:
sampler_name = "dpmpp_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerCustom:
@ -234,13 +271,15 @@ class SamplerCustom:
NODE_CLASS_MAPPINGS = {
"SamplerCustom": SamplerCustom,
"BasicScheduler": BasicScheduler,
"KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"VPScheduler": VPScheduler,
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"BasicScheduler": BasicScheduler,
"SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas,
}

View File

@ -0,0 +1,175 @@
import nodes
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "crop"
CATEGORY = "image/transform"
def crop(self, image, width, height, x, y):
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
}

View File

@ -1,4 +1,5 @@
import comfy.utils
import torch
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
@ -67,8 +68,43 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier
return (samples_out,)
class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2)
m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))
s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)
t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
}

View File

@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteLCM(torch.nn.Module):
class ModelSamplingDiscreteDistilled(torch.nn.Module):
original_timesteps = 50
def __init__(self):
super().__init__()
self.sigma_data = 1.0
@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
original_timesteps = 50
self.skip_steps = timesteps // original_timesteps
self.skip_steps = timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
for x in range(original_timesteps):
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
@ -55,18 +56,23 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
def sigma(self, timestep):
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas):
@ -111,7 +117,7 @@ class ModelSamplingDiscrete:
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteLCM
sampling_base = ModelSamplingDiscreteDistilled
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
@ -123,6 +129,36 @@ class ModelSamplingDiscrete:
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG:
@classmethod
def INPUT_TYPES(s):
@ -164,5 +200,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"RescaleCFG": RescaleCFG,
}

View File

@ -0,0 +1,53 @@
import torch
import comfy.utils
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp
m = model.clone()
if downscale_after_skip:
m.set_model_input_block_patch_after_skip(input_block_patch)
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
}

View File

@ -0,0 +1,89 @@
import nodes
import torch
import comfy.utils
import comfy.sd
import folder_paths
class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
class SVD_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
if augmentation_level > 0:
encode_pixels += torch.randn_like(pixels) * augmentation_level
t = vae.encode(encode_pixels)
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}

View File

@ -681,6 +681,7 @@ def validate_prompt(prompt):
return (True, None, list(good_outputs), node_errors)
MAXIMUM_HISTORY_SIZE = 10000
class PromptQueue:
def __init__(self, server):
@ -699,10 +700,12 @@ class PromptQueue:
self.server.queue_updated()
self.not_empty.notify()
def get(self):
def get(self, timeout=None):
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait()
self.not_empty.wait(timeout=timeout)
if timeout is not None and len(self.queue) == 0:
return None
item = heapq.heappop(self.queue)
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
@ -713,6 +716,8 @@ class PromptQueue:
def task_done(self, item_id, outputs):
with self.mutex:
prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o]
@ -747,10 +752,20 @@ class PromptQueue:
return True
return False
def get_history(self, prompt_id=None):
def get_history(self, prompt_id=None, max_items=None, offset=-1):
with self.mutex:
if prompt_id is None:
return copy.deepcopy(self.history)
out = {}
i = 0
if offset < 0 and max_items is not None:
offset = len(self.history) - max_items
for k in self.history:
if i >= offset:
out[k] = self.history[k]
if max_items is not None and len(out) >= max_items:
break
i += 1
return out
elif prompt_id in self.history:
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
else:

View File

@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
filename_list_cache = {}
if not os.path.exists(input_directory):
os.makedirs(input_directory)
try:
os.makedirs(input_directory)
except:
print("Failed to create input directory")
def set_output_directory(output_dir):
global output_directory
@ -228,8 +231,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
err = "**** ERROR: Saving image outside the output folder is not allowed." + \
"\n full_output_folder: " + os.path.abspath(full_output_folder) + \
"\n output_dir: " + output_dir + \
"\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
print(err)
raise Exception(err)
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1

View File

@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
self.taesd = taesd
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0[:1])[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
x_sample = self.taesd.decode(x0[:1])[0].detach()
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)

41
main.py
View File

@ -88,18 +88,37 @@ def cuda_malloc_warning():
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
item, item_id = q.get()
execution_start_time = time.perf_counter()
prompt_id = item[1]
e.execute(item[2], prompt_id, item[3], item[4])
q.task_done(item_id, e.outputs_ui)
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
comfy.model_management.soft_empty_cache()
while True:
timeout = None
if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
queue_item = q.get(timeout=timeout)
if queue_item is not None:
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id, e.outputs_ui)
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time))
if need_gc:
current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval:
gc.collect()
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time
need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

View File

@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
c = []
for t in conditioning:
d = t[1].copy()
d['start_percent'] = 1.0 - start
d['end_percent'] = 1.0 - end
d['start_percent'] = start
d['end_percent'] = end
n = [t[0], d]
c.append(n)
return (c, )
@ -572,10 +572,69 @@ class LoraLoader:
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class VAELoader:
class LoraLoaderModelOnly(LoraLoader):
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -583,8 +642,11 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)
@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
@ -1275,6 +1337,7 @@ class SaveImage:
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
@ -1308,7 +1371,7 @@ class SaveImage:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,
"subfolder": subfolder,
@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
@classmethod
def INPUT_TYPES(s):
@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -1759,7 +1824,7 @@ def load_custom_nodes():
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
possible_modules = os.listdir(os.path.realpath(custom_node_path))
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
@ -1799,6 +1864,9 @@ def init_custom_nodes():
"nodes_custom_sampler.py",
"nodes_hypertile.py",
"nodes_model_advanced.py",
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
]
for node_file in extras_files:

View File

@ -431,7 +431,10 @@ class PromptServer():
@routes.get("/history")
async def get_history(request):
return web.json_response(self.prompt_queue.get_history())
max_items = request.rel_url.query.get("max_items", None)
if max_items is not None:
max_items = int(max_items)
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
@routes.get("/history/{prompt_id}")
async def get_history(request):
@ -573,7 +576,7 @@ class PromptServer():
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

View File

@ -20,6 +20,7 @@ async function setup() {
// Modify the response data to add some checkpoints
const objectInfo = JSON.parse(data);
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
data = JSON.stringify(objectInfo, undefined, "\t");

View File

@ -0,0 +1,196 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("extensions", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
it("calls each extension hook", async () => {
const mockExtension = {
name: "TestExtension",
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn(),
};
const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.init).toHaveBeenCalledWith(app);
// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
expect(defs).toHaveProperty("KSampler");
expect(defs).toHaveProperty("LoadImage");
// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
// Before register node def will be called once per node type
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < nodeCount; i++) {
// It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
expect(nodeClass.name).toBe("ComfyNode");
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
expect(nodeDef.name).toBe(nodeNames[i]);
expect(nodeDef).toHaveProperty("input");
expect(nodeDef).toHaveProperty("output");
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
// it gets sent the graph data that is going to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
// A node created is fired for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// Each node then calls loadedGraphNode to allow them to be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledWith(app);
// Ensure hooks are called in the correct order
const callOrder = [
"init",
"addCustomNodeDefs",
"getCustomWidgets",
"beforeRegisterNodeDef",
"registerCustomNodes",
"beforeConfigureGraph",
"nodeCreated",
"loadedGraphNode",
"afterConfigureGraph",
"setup",
];
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]];
const fn2 = mockExtension[callOrder[i]];
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
}
graph.clear();
// Ensure adding a new node calls the correct callback
ez.LoadImage();
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
// Reload the graph to ensure correct hooks are fired
await graph.reload();
// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
});
it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe("TestNode");
expect(inputName).toBe("test_input");
expect(inputData[0]).toBe("CUSTOMWIDGET");
expect(inputData[1]?.hello).toBe("world");
expect(app).toStrictEqual(app);
return {
widget: node.addWidget("button", inputName, "hello", () => {}),
};
});
// Register our extension that adds a custom node + widget type
const mockExtension = {
name: "TestExtension",
addCustomNodeDefs: (nodeDefs) => {
nodeDefs["TestNode"] = {
output: [],
output_name: [],
output_is_list: [],
name: "TestNode",
display_name: "TestNode",
category: "Test",
input: {
required: {
test_input: ["CUSTOMWIDGET", { hello: "world" }],
},
},
};
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock,
};
}),
};
const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
graph.clear();
expect(widgetMock).toBeCalledTimes(0);
const node = ez.TestNode();
expect(widgetMock).toBeCalledTimes(1);
// Ensure our custom widget is created
expect(node.inputs.length).toBe(0);
expect(node.widgets.length).toBe(1);
const w = node.widgets[0].widget;
expect(w.name).toBe("test_input");
expect(w.type).toBe("button");
});
});

View File

@ -0,0 +1,818 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
/**
*
* @param {*} app
* @param {*} graph
* @param {*} name
* @param {*} nodes
* @returns { Promise<InstanceType<import("../utils/ezgraph")["EzNode"]>> }
*/
async function convertToGroup(app, graph, name, nodes) {
// Select the nodes we are converting
for (const n of nodes) {
n.select(true);
}
expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual(
nodes.map((n) => n.id + "").sort((a, b) => +a - +b)
);
global.prompt = jest.fn().mockImplementation(() => name);
const groupNode = await nodes[0].menu["Convert to Group Node"].call(false);
// Check group name was requested
expect(window.prompt).toHaveBeenCalled();
// Ensure old nodes are removed
for (const n of nodes) {
expect(n.isRemoved).toBeTruthy();
}
expect(groupNode.type).toEqual("workflow/" + name);
return graph.find(groupNode);
}
/**
* @param { Record<string, string | number> | number[] } idMap
* @param { Record<string, Record<string, unknown>> } valueMap
*/
function getOutput(idMap = {}, valueMap = {}) {
if (idMap instanceof Array) {
idMap = idMap.reduce((p, n) => {
p[n] = n + "";
return p;
}, {});
}
const expected = {
1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" },
4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" },
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
...valueMap?.[5],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" },
7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" },
};
// Map old IDs to new at the top level
const mapped = {};
for (const oldId in idMap) {
mapped[idMap[oldId]] = expected[oldId];
delete expected[oldId];
}
Object.assign(mapped, expected);
// Map old IDs to new inside links
for (const k in mapped) {
for (const input in mapped[k].inputs) {
const v = mapped[k].inputs[input];
if (v instanceof Array) {
if (v[0] in idMap) {
v[0] = idMap[v[0]] + "";
}
}
}
}
return mapped;
}
test("can be created from selected nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
// Ensure links are now to the group node
expect(group.inputs).toHaveLength(2);
expect(group.outputs).toHaveLength(3);
expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]);
expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]);
// ckpt clip to both clip inputs on the group
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[group.id, 0],
[group.id, 1],
]);
// group conditioning to sampler
expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
// group conditioning 2 to sampler
expect(
group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
).toEqual([[nodes.sampler.id, 2]]);
// group latent to sampler
expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("maintains all output links on conversion", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const save2 = ez.SaveImage(...nodes.decode.outputs);
const save3 = ez.SaveImage(...nodes.decode.outputs);
// Ensure an output with multiple links maintains them on convert to group
const group = await convertToGroup(app, graph, "test", [nodes.sampler, nodes.decode]);
expect(group.outputs[0].connections.length).toBe(3);
expect(group.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(group.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(group.outputs[0].connections[2].targetNode.id).toBe(save3.id);
// and they're still linked when converting back to nodes
const newNodes = group.menu["Convert to nodes"].call();
const decode = graph.find(newNodes.find((n) => n.type === "VAEDecode"));
expect(decode.outputs[0].connections.length).toBe(3);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id);
});
test("can be be converted back to nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler];
const group = await convertToGroup(app, graph, "test", toConvert);
// Edit some values to ensure they are set back onto the converted nodes
expect(group.widgets["text"].value).toBe("positive");
group.widgets["text"].value = "pos";
expect(group.widgets["CLIPTextEncode text"].value).toBe("negative");
group.widgets["CLIPTextEncode text"].value = "neg";
expect(group.widgets["width"].value).toBe(512);
group.widgets["width"].value = 1024;
expect(group.widgets["sampler_name"].value).toBe("euler");
group.widgets["sampler_name"].value = "ddim";
expect(group.widgets["control_after_generate"].value).toBe("randomize");
group.widgets["control_after_generate"].value = "fixed";
/** @type { Array<any> } */
group.menu["Convert to nodes"].call();
// ensure widget values are set
const pos = graph.find(nodes.pos.id);
expect(pos.node.type).toBe("CLIPTextEncode");
expect(pos.widgets["text"].value).toBe("pos");
const neg = graph.find(nodes.neg.id);
expect(neg.node.type).toBe("CLIPTextEncode");
expect(neg.widgets["text"].value).toBe("neg");
const empty = graph.find(nodes.empty.id);
expect(empty.node.type).toBe("EmptyLatentImage");
expect(empty.widgets["width"].value).toBe(1024);
const sampler = graph.find(nodes.sampler.id);
expect(sampler.node.type).toBe("KSampler");
expect(sampler.widgets["sampler_name"].value).toBe("ddim");
expect(sampler.widgets["control_after_generate"].value).toBe("fixed");
// validate links
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[pos.id, 0],
[neg.id, 0],
]);
expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 2],
]);
expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("it can embed reroutes as inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add and connect a reroute to the clip text encodes
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(nodes.pos.inputs[0]);
reroute.outputs[0].connectTo(nodes.neg.inputs[0]);
// Convert to group and ensure we only have 1 input of the correct type
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]);
expect(group.inputs).toHaveLength(1);
expect(group.inputs[0].input.type).toEqual("CLIP");
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("it can embed reroutes as outputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add a reroute with no output so we output IMAGE even though its used internally
const reroute = ez.Reroute();
nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]);
// Convert to group and ensure there is an IMAGE output
const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
expect(group.outputs).toHaveLength(1);
expect(group.outputs[0].output.type).toEqual("IMAGE");
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id]));
});
test("it can embed reroutes as pipes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Use reroutes as a pipe
const rerouteModel = ez.Reroute();
const rerouteClip = ez.Reroute();
const rerouteVae = ez.Reroute();
nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]);
nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]);
nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]);
const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
group.outputs[0].connectTo(nodes.sampler.inputs.model);
group.outputs[1].connectTo(nodes.pos.inputs.clip);
group.outputs[1].connectTo(nodes.neg.inputs.clip);
});
test("can handle reroutes used internally", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let reroutes = [];
let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
reroutes.push(reroute);
}
prevNode.outputs[0].connectTo(nodes.sampler.inputs.model);
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("creates with widget values from inner nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt";
nodes.pos.widgets.text.value = "hello";
nodes.neg.widgets.text.value = "world";
nodes.empty.widgets.width.value = 256;
nodes.empty.widgets.height.value = 1024;
nodes.sampler.widgets.seed.value = 1;
nodes.sampler.widgets.control_after_generate.value = "increment";
nodes.sampler.widgets.steps.value = 8;
nodes.sampler.widgets.cfg.value = 4.5;
nodes.sampler.widgets.sampler_name.value = "uni_pc";
nodes.sampler.widgets.scheduler.value = "karras";
nodes.sampler.widgets.denoise.value = 0.9;
const group = await convertToGroup(app, graph, "test", [
nodes.ckpt,
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
]);
expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt");
expect(group.widgets["text"].value).toEqual("hello");
expect(group.widgets["CLIPTextEncode text"].value).toEqual("world");
expect(group.widgets["width"].value).toEqual(256);
expect(group.widgets["height"].value).toEqual(1024);
expect(group.widgets["seed"].value).toEqual(1);
expect(group.widgets["control_after_generate"].value).toEqual("increment");
expect(group.widgets["steps"].value).toEqual(8);
expect(group.widgets["cfg"].value).toEqual(4.5);
expect(group.widgets["sampler_name"].value).toEqual("uni_pc");
expect(group.widgets["scheduler"].value).toEqual("karras");
expect(group.widgets["denoise"].value).toEqual(0.9);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
[nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
[nodes.pos.id]: { text: "hello" },
[nodes.neg.id]: { text: "world" },
[nodes.empty.id]: { width: 256, height: 1024 },
[nodes.sampler.id]: {
seed: 1,
steps: 8,
cfg: 4.5,
sampler_name: "uni_pc",
scheduler: "karras",
denoise: 0.9,
},
})
);
});
test("group inputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[1]);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("group outputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute1 = ez.Reroute();
const reroute2 = ez.Reroute();
group.outputs[0].connectTo(reroute1.inputs[0]);
group.outputs[1].connectTo(reroute2.inputs[0]);
reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("groups can connect to each other", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
group1.outputs[0].connectTo(group2.inputs["positive"]);
group1.outputs[1].connectTo(group2.inputs["negative"]);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
);
});
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
const { api } = require("../../web/scripts/api");
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${nodes.save.id}`,
output: {
images: [
{
filename: "test.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test.png",
type: "output",
},
]);
// Reload
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
await app.loadGraphData(JSON.parse(workflow));
group = graph.find(group);
// Trigger inner nodes to get created
group.node["getInnerNodes"]();
// Check it works for internal node ids
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${group.id}:5`,
output: {
images: [
{
filename: "test2.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test2.png",
type: "output",
},
]);
});
test("allows widgets to be converted to inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
group.widgets[0].convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs["text"]);
primitive.widgets[0].value = "hello";
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id], {
[nodes.pos.id]: { text: "hello" },
})
);
});
test("can be copied", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
group1.widgets["text"].value = "hello";
group1.widgets["width"].value = 256;
group1.widgets["seed"].value = 1;
// Clone the node
group1.menu.Clone.call();
expect(app.graph._nodes).toHaveLength(3);
const group2 = graph.find(app.graph._nodes[2]);
expect(group2.node.type).toEqual("workflow/test");
expect(group2.id).not.toEqual(group1.id);
// Reconnect ckpt
nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]);
group2.widgets["text"].value = "world";
group2.widgets["width"].value = 1024;
group2.widgets["seed"].value = 100;
let i = 0;
expect((await graph.toPrompt()).output).toEqual({
...getOutput([nodes.empty.id, nodes.pos.id, nodes.neg.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], {
[nodes.empty.id]: { width: 256 },
[nodes.pos.id]: { text: "hello" },
[nodes.sampler.id]: { seed: 1 },
}),
...getOutput(
{
[nodes.empty.id]: `${group2.id}:${i++}`,
[nodes.pos.id]: `${group2.id}:${i++}`,
[nodes.neg.id]: `${group2.id}:${i++}`,
[nodes.sampler.id]: `${group2.id}:${i++}`,
[nodes.decode.id]: `${group2.id}:${i++}`,
[nodes.save.id]: `${group2.id}:${i++}`,
},
{
[nodes.empty.id]: { width: 1024 },
[nodes.pos.id]: { text: "world" },
[nodes.sampler.id]: { seed: 100 },
}
),
});
graph.arrange();
});
test("is embedded in workflow", async () => {
let { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
// Clear the environment
({ ez, graph, app } = await start({
resetEnv: true,
}));
// Ensure the node isnt registered
expect(() => ez["workflow/test"]).toThrow();
// Reload the workflow
await app.loadGraphData(JSON.parse(workflow));
// Ensure the node is found
group = graph.find(group);
// Generate prompt and ensure it is as expected
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
})
);
});
test("shows missing node error on missing internal node when loading graph data", async () => {
const { graph } = await start();
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
await graph.app.loadGraphData({
last_node_id: 3,
last_link_id: 1,
nodes: [
{
id: 3,
type: "workflow/testerror",
},
],
links: [],
groups: [],
config: {},
extra: {
groupNodes: {
testerror: {
nodes: [
{
type: "NotKSampler",
},
{
type: "NotVAEDecode",
},
],
},
},
},
});
expect(dialogShow).toBeCalledTimes(1);
const call = dialogShow.mock.calls[0][0].innerHTML;
expect(call).toContain("the following node types were not found");
expect(call).toContain("NotKSampler");
expect(call).toContain("NotVAEDecode");
expect(call).toContain("workflow/testerror");
});
test("maintains widget inputs on conversion back to nodes", async () => {
const { ez, graph, app } = await start();
let pos = ez.CLIPTextEncode({ text: "positive" });
pos.node.title = "Positive";
let neg = ez.CLIPTextEncode({ text: "negative" });
neg.node.title = "Negative";
pos.widgets.text.convertToInput();
neg.widgets.text.convertToInput();
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(pos.inputs.text);
primitive.outputs[0].connectTo(neg.inputs.text);
const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
// This will use a primitive widget named 'value'
expect(group.widgets.length).toBe(1);
expect(group.widgets["value"].value).toBe("positive");
const newNodes = group.menu["Convert to nodes"].call();
pos = graph.find(newNodes.find((n) => n.title === "Positive"));
neg = graph.find(newNodes.find((n) => n.title === "Negative"));
primitive = graph.find(newNodes.find((n) => n.type === "PrimitiveNode"));
expect(pos.inputs).toHaveLength(2);
expect(neg.inputs).toHaveLength(2);
expect(primitive.outputs[0].connections).toHaveLength(2);
expect((await graph.toPrompt()).output).toEqual({
1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
const save = ez.SaveImage();
const empty = ez.EmptyLatentImage();
const decode = ez.VAEDecode();
scale.outputs.LATENT.connectTo(decode.inputs.samples);
decode.outputs.IMAGE.connectTo(save.inputs.images);
empty.outputs.LATENT.connectTo(scale.inputs.samples);
const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]);
const widgets = group.widgets.map((w) => w.widget.name);
expect(widgets).toStrictEqual([
"width",
"height",
"batch_size",
"upscale_method",
"LatentUpscale width",
"LatentUpscale height",
"crop",
"filename_prefix",
]);
});
test("adds output for external links when converting to group", async () => {
const { ez, graph, app } = await start();
const img = ez.EmptyLatentImage();
let decode = ez.VAEDecode(...img.outputs);
const preview1 = ez.PreviewImage(...decode.outputs);
const preview2 = ez.PreviewImage(...decode.outputs);
const group = await convertToGroup(app, graph, "test", [img, decode, preview1]);
// Ensure we have an output connected to the 2nd preview node
expect(group.outputs.length).toBe(1);
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(preview2.id);
// Convert back and ensure bothe previews are still connected
group.menu["Convert to nodes"].call();
decode = graph.find(decode);
expect(decode.outputs[0].connections.length).toBe(2);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(preview1.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(preview2.id);
});
test("adds output for external links when converting to group when nodes are not in execution order", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
const ckpt = ez.CheckpointLoaderSimple();
const empty = ez.EmptyLatentImage();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const decode1 = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode1.outputs.IMAGE);
ckpt.outputs.MODEL.connectTo(sampler.inputs.model);
pos.outputs.CONDITIONING.connectTo(sampler.inputs.positive);
neg.outputs.CONDITIONING.connectTo(sampler.inputs.negative);
empty.outputs.LATENT.connectTo(sampler.inputs.latent_image);
const encode = ez.VAEEncode(decode1.outputs.IMAGE);
const vae = ez.VAELoader();
const decode2 = ez.VAEDecode(encode.outputs.LATENT, vae.outputs.VAE);
const preview = ez.PreviewImage(decode2.outputs.IMAGE);
vae.outputs.VAE.connectTo(encode.inputs.vae);
const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]);
expect(group.outputs.length).toBe(3);
expect(group.outputs[0].output.name).toBe("VAE");
expect(group.outputs[0].output.type).toBe("VAE");
expect(group.outputs[1].output.name).toBe("IMAGE");
expect(group.outputs[1].output.type).toBe("IMAGE");
expect(group.outputs[2].output.name).toBe("LATENT");
expect(group.outputs[2].output.type).toBe("LATENT");
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[0].connections[0].targetInput.index).toBe(1);
expect(group.outputs[1].connections.length).toBe(1);
expect(group.outputs[1].connections[0].targetNode.id).toBe(save.id);
expect(group.outputs[1].connections[0].targetInput.index).toBe(0);
expect(group.outputs[2].connections.length).toBe(1);
expect(group.outputs[2].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[2].connections[0].targetInput.index).toBe(0);
expect((await graph.toPrompt()).output).toEqual({
...getOutput({ 1: ckpt.id, 2: pos.id, 3: neg.id, 4: empty.id, 5: sampler.id, 6: decode1.id, 7: save.id }),
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: vae.node.type },
[encode.id]: { inputs: { pixels: ["6", 0], vae: [vae.id + "", 0] }, class_type: encode.node.type },
[decode2.id]: { inputs: { samples: [encode.id + "", 0], vae: [vae.id + "", 0] }, class_type: decode2.node.type },
[preview.id]: { inputs: { images: [decode2.id + "", 0] }, class_type: preview.node.type },
});
});
test("works with IMAGEUPLOAD widget", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const preview1 = ez.PreviewImage(img.outputs[0]);
const group = await convertToGroup(app, graph, "test", [img, preview1]);
const widget = group.widgets["upload"];
expect(widget).toBeTruthy();
expect(widget.widget.type).toBe("button");
});
test("internal primitive populates widgets for all linked inputs", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const scale1 = ez.ImageScale(img.outputs[0]);
const scale2 = ez.ImageScale(img.outputs[0]);
ez.PreviewImage(scale1.outputs[0]);
ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.width.convertToInput();
scale2.widgets.height.convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(scale1.inputs.width);
primitive.outputs[0].connectTo(scale2.inputs.height);
const group = await convertToGroup(app, graph, "test", [img, primitive, scale1, scale2]);
group.widgets.value.value = 100;
expect((await graph.toPrompt()).output).toEqual({
1: {
inputs: { image: img.widgets.image.value, upload: "image" },
class_type: "LoadImage",
},
2: {
inputs: { upscale_method: "nearest-exact", width: 100, height: 512, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
3: {
inputs: { upscale_method: "nearest-exact", width: 512, height: 100, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
4: { inputs: { images: ["2", 0] }, class_type: "PreviewImage" },
5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" },
});
});
test("primitive control widgets values are copied on convert", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
sampler.widgets.seed.convertToInput();
sampler.widgets.sampler_name.convertToInput();
let p1 = ez.PrimitiveNode();
let p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(sampler.inputs.seed);
p2.outputs[0].connectTo(sampler.inputs.sampler_name);
p1.widgets.control_after_generate.value = "increment";
p2.widgets.control_after_generate.value = "decrement";
p2.widgets.control_filter_list.value = "/.*/";
p2.node.title = "p2";
const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]);
expect(group.widgets.control_after_generate.value).toBe("increment");
expect(group.widgets["p2 control_after_generate"].value).toBe("decrement");
expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/");
group.widgets.control_after_generate.value = "fixed";
group.widgets["p2 control_after_generate"].value = "randomize";
group.widgets["p2 control_filter_list"].value = "/.+/";
group.menu["Convert to nodes"].call();
p1 = graph.find(p1);
p2 = graph.find(p2);
expect(p1.widgets.control_after_generate.value).toBe("fixed");
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
});

View File

@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType
* @param { boolean } hasControlWidget
* @param { number } controlWidgetCount
* @returns
*/
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) {
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
// Connect to primitive and ensure its still connected after
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input);
@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added
if (hasControlWidget) {
if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
}
// Ensure we dont have other widgets
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget);
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
});
return primitive;
@ -55,8 +59,8 @@ describe("widget inputs", () => {
});
[
{ name: "int", type: "INT", widget: "number", control: true },
{ name: "float", type: "FLOAT", widget: "number", control: true },
{ name: "int", type: "INT", widget: "number", control: 1 },
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
{ name: "text", type: "STRING" },
{
name: "customtext",
@ -64,7 +68,7 @@ describe("widget inputs", () => {
opt: { multiline: true },
},
{ name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: true },
{ name: "combo", type: ["a", "b", "c"], control: 2 },
].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({
@ -106,7 +110,7 @@ describe("widget inputs", () => {
n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
// Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect();
@ -198,8 +202,8 @@ describe("widget inputs", () => {
});
expect(dialogShow).toBeCalledTimes(1);
expect(dialogShow.mock.calls[0][0]).toContain("the following node types were not found");
expect(dialogShow.mock.calls[0][0]).toContain("TestNode");
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
});
test("defaultInput widgets can be converted back to inputs", async () => {
@ -226,7 +230,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", true);
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
w = n.widgets.example;
@ -258,7 +262,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", true);
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
@ -316,4 +320,76 @@ describe("widget inputs", () => {
n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
});
test("combo primitive can filter list when control_after_generate called", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
},
});
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
const control = p.widgets.control_after_generate.widget;
const filter = p.widgets.control_filter_list;
expect(p.widgets.length).toBe(3);
control.value = "increment";
expect(value.value).toBe("A");
// Manually trigger after queue when set to increment
control["afterQueued"]();
expect(value.value).toBe("B");
// Filter to items containing D
filter.value = "D";
control["afterQueued"]();
expect(value.value).toBe("D");
control["afterQueued"]();
expect(value.value).toBe("DD");
// Check decrement
value.value = "BBB";
control.value = "decrement";
filter.value = "B";
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("B");
// Check regex works
value.value = "BBB";
filter.value = "/[AB]|^C$/";
control["afterQueued"]();
expect(value.value).toBe("AAA");
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("AA");
control["afterQueued"]();
expect(value.value).toBe("C");
control["afterQueued"]();
expect(value.value).toBe("B");
control["afterQueued"]();
expect(value.value).toBe("A");
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
// Ensure it doesnt apply when fixed
control.value = "fixed";
value.value = "B";
filter.value = "C";
control["afterQueued"]();
expect(value.value).toBe("B");
});
});

View File

@ -150,7 +150,7 @@ export class EzNodeMenuItem {
if (selectNode) {
this.node.select();
}
this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
}
}
@ -240,8 +240,12 @@ export class EzNode {
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
}
select() {
this.app.canvas.selectNode(this.node);
get isRemoved() {
return !this.app.graph.getNodeById(this.id);
}
select(addToSelection = false) {
this.app.canvas.selectNode(this.node, addToSelection);
}
// /**
@ -275,12 +279,17 @@ export class EzNode {
if (!s) return p;
const name = s[nameProperty];
const item = new ctor(this, i, s);
// @ts-ignore
if (!name || name in p) {
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
p.push(item);
if (name) {
// @ts-ignore
if (name in p) {
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
}
}
// @ts-ignore
p.push((p[name] = new ctor(this, i, s)));
p[name] = item;
return p;
}, Object.assign([], { $: this }));
}
@ -348,6 +357,19 @@ export class EzGraph {
}, 10);
});
}
/**
* @returns { Promise<{
* workflow: {},
* output: Record<string, {
* class_name: string,
* inputs: Record<string, [string, number] | unknown>
* }>}> }
*/
toPrompt() {
// @ts-ignore
return this.app.graphToPrompt();
}
}
export const Ez = {
@ -356,12 +378,12 @@ export const Ez = {
* @example
* const { ez, graph } = Ez.graph(app);
* graph.clear();
* const [model, clip, vae] = ez.CheckpointLoaderSimple();
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" });
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" });
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage());
* const [image] = ez.VAEDecode(latent, vae);
* const saveNode = ez.SaveImage(image).node;
* const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
* const [image] = ez.VAEDecode(latent, vae).outputs;
* const saveNode = ez.SaveImage(image);
* console.log(saveNode);
* graph.arrange();
* @param { app } app

View File

@ -1,21 +1,29 @@
const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph");
const lg = require("./litegraph");
/**
*
* @param { Parameters<mockApi>[0] } config
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
* @returns
*/
export async function start(config = undefined) {
export async function start(config = {}) {
if(config.resetEnv) {
jest.resetModules();
jest.resetAllMocks();
lg.setup(global);
}
mockApi(config);
const { app } = require("../../web/scripts/app");
config.preSetup?.(app);
await app.setup();
return Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]);
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
}
/**
* @param { ReturnType<Ez["graph"]>["graph"] } graph
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
* @param { ReturnType<Ez["graph"]>["graph"] } graph
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
*/
export async function checkBeforeAndAfterReload(graph, cb) {
await cb(false);
@ -24,10 +32,10 @@ export async function checkBeforeAndAfterReload(graph, cb) {
}
/**
* @param { string } name
* @param { Record<string, string | [string | string[], any]> } input
* @param { string } name
* @param { Record<string, string | [string | string[], any]> } input
* @param { (string | string[])[] | Record<string, string | string[]> } output
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
@ -37,19 +45,19 @@ export function makeNodeDef(name, input, output = {}) {
output_name: [],
output_is_list: [],
input: {
required: {}
required: {},
},
};
for(const k in input) {
for (const k in input) {
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
}
if(output instanceof Array) {
if (output instanceof Array) {
output = output.reduce((p, c) => {
p[c] = c;
return p;
}, {})
}, {});
}
for(const k in output) {
for (const k in output) {
nodeDef.output.push(output[k]);
nodeDef.output_name.push(k);
nodeDef.output_is_list.push(false);
@ -68,4 +76,31 @@ export function assertNotNullOrUndefined(x) {
expect(x).not.toEqual(null);
expect(x).not.toEqual(undefined);
return true;
}
}
/**
*
* @param { ReturnType<Ez["graph"]>["ez"] } ez
* @param { ReturnType<Ez["graph"]>["graph"] } graph
*/
export function createDefaultWorkflow(ez, graph) {
graph.clear();
const ckpt = ez.CheckpointLoaderSimple();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const empty = ez.EmptyLatentImage();
const sampler = ez.KSampler(
ckpt.outputs.MODEL,
pos.outputs.CONDITIONING,
neg.outputs.CONDITIONING,
empty.outputs.LATENT
);
const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode.outputs.IMAGE);
graph.arrange();
return { ckpt, pos, neg, empty, sampler, decode, save };
}

View File

@ -30,16 +30,20 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
}
const events = new EventTarget();
const mockApi = {
addEventListener: events.addEventListener.bind(events),
removeEventListener: events.removeEventListener.bind(events),
dispatchEvent: events.dispatchEvent.bind(events),
getSystemStats: jest.fn(),
getExtensions: jest.fn(() => mockExtensions),
getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
};
jest.mock("../../web/scripts/api", () => ({
get api() {
return {
addEventListener: jest.fn(),
getSystemStats: jest.fn(),
getExtensions: jest.fn(() => mockExtensions),
getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
};
return mockApi;
},
}));
}

View File

@ -174,6 +174,213 @@ const colorPalettes = {
"tr-odd-bg-color": "#073642",
}
},
},
"arc": {
"id": "arc",
"name": "Arc",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "",
"CLEAR_BACKGROUND_COLOR": "#2b2f38",
"NODE_TITLE_COLOR": "#b2b7bd",
"NODE_SELECTED_TITLE_COLOR": "#FFF",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#AAA",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2b2f38",
"NODE_DEFAULT_BGCOLOR": "#242730",
"NODE_DEFAULT_BOXCOLOR": "#6e7581",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#FFF",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 22,
"WIDGET_BGCOLOR": "#2b2f38",
"WIDGET_OUTLINE_COLOR": "#6e7581",
"WIDGET_TEXT_COLOR": "#DDD",
"WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#fff",
"bg-color": "#2b2f38",
"comfy-menu-bg": "#242730",
"comfy-input-bg": "#2b2f38",
"input-text": "#ddd",
"descrip-text": "#b2b7bd",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#6e7581",
"tr-even-bg-color": "#2b2f38",
"tr-odd-bg-color": "#242730"
}
},
},
"nord": {
"id": "nord",
"name": "Nord",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "",
"CLEAR_BACKGROUND_COLOR": "#212732",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2e3440",
"NODE_DEFAULT_BGCOLOR": "#161b22",
"NODE_DEFAULT_BOXCOLOR": "#545d70",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#2e3440",
"WIDGET_OUTLINE_COLOR": "#545d70",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#2e3440",
"comfy-menu-bg": "#161b22",
"comfy-input-bg": "#2e3440",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#545d70",
"tr-even-bg-color": "#2e3440",
"tr-odd-bg-color": "#161b22"
}
},
},
"github": {
"id": "github",
"name": "Github",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "",
"CLEAR_BACKGROUND_COLOR": "#040506",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#161b22",
"NODE_DEFAULT_BGCOLOR": "#13171d",
"NODE_DEFAULT_BOXCOLOR": "#30363d",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#161b22",
"WIDGET_OUTLINE_COLOR": "#30363d",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#161b22",
"comfy-menu-bg": "#13171d",
"comfy-input-bg": "#161b22",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#30363d",
"tr-even-bg-color": "#161b22",
"tr-odd-bg-color": "#13171d"
}
},
}
};

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
import { app } from "../../scripts/app.js";
import { ComfyDialog, $el } from "../../scripts/ui.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Adds the ability to save and add multiple nodes as a template
// To save:
@ -34,7 +35,7 @@ class ManageTemplates extends ComfyDialog {
type: "file",
accept: ".json",
multiple: true,
style: {display: "none"},
style: { display: "none" },
parent: document.body,
onchange: () => this.importAll(),
});
@ -109,13 +110,13 @@ class ManageTemplates extends ComfyDialog {
return;
}
const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const json = JSON.stringify({ templates: this.templates }, null, 2); // convert the data to a JSON string
const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "node_templates.json",
style: {display: "none"},
style: { display: "none" },
parent: document.body,
});
a.click();
@ -291,11 +292,11 @@ app.registerExtension({
setup() {
const manage = new ManageTemplates();
const clipboardAction = (cb) => {
const clipboardAction = async (cb) => {
// We use the clipboard functions but dont want to overwrite the current user clipboard
// Restore it after we've run our callback
const old = localStorage.getItem("litegrapheditor_clipboard");
cb();
await cb();
localStorage.setItem("litegrapheditor_clipboard", old);
};
@ -309,13 +310,31 @@ app.registerExtension({
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
const name = prompt("Enter name");
if (!name || !name.trim()) return;
if (!name?.trim()) return;
clipboardAction(() => {
app.canvas.copyToClipboard();
let data = localStorage.getItem("litegrapheditor_clipboard");
data = JSON.parse(data);
const nodeIds = Object.keys(app.canvas.selected_nodes);
for (let i = 0; i < nodeIds.length; i++) {
const node = app.graph.getNodeById(nodeIds[i]);
const nodeData = node?.constructor.nodeData;
let groupData = GroupNodeHandler.getGroupData(node);
if (groupData) {
groupData = groupData.nodeData;
if (!data.groupNodes) {
data.groupNodes = {};
}
data.groupNodes[nodeData.name] = groupData;
data.nodes[i].type = nodeData.name;
}
}
manage.templates.push({
name,
data: localStorage.getItem("litegrapheditor_clipboard"),
data: JSON.stringify(data),
});
manage.store();
});
@ -323,15 +342,19 @@ app.registerExtension({
});
// Map each template to a menu item
const subItems = manage.templates.map((t) => ({
content: t.name,
callback: () => {
clipboardAction(() => {
localStorage.setItem("litegrapheditor_clipboard", t.data);
app.canvas.pasteFromClipboard();
});
},
}));
const subItems = manage.templates.map((t) => {
return {
content: t.name,
callback: () => {
clipboardAction(async () => {
const data = JSON.parse(t.data);
await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {});
localStorage.setItem("litegrapheditor_clipboard", t.data);
app.canvas.pasteFromClipboard();
});
},
};
});
subItems.push(null, {
content: "Manage",

View File

@ -0,0 +1,150 @@
import { app } from "../../scripts/app.js";
const MAX_HISTORY = 50;
let undo = [];
let redo = [];
let activeState = null;
let isOurLoad = false;
function checkState() {
const currentState = app.graph.serialize();
if (!graphEqual(activeState, currentState)) {
undo.push(activeState);
if (undo.length > MAX_HISTORY) {
undo.shift();
}
activeState = clone(currentState);
redo.length = 0;
}
}
const loadGraphData = app.loadGraphData;
app.loadGraphData = async function () {
const v = await loadGraphData.apply(this, arguments);
if (isOurLoad) {
isOurLoad = false;
} else {
checkState();
}
return v;
};
function clone(obj) {
try {
if (typeof structuredClone !== "undefined") {
return structuredClone(obj);
}
} catch (error) {
// structuredClone is stricter than using JSON.parse/stringify so fallback to that
}
return JSON.parse(JSON.stringify(obj));
}
function graphEqual(a, b, root = true) {
if (a === b) return true;
if (typeof a == "object" && a && typeof b == "object" && b) {
const keys = Object.getOwnPropertyNames(a);
if (keys.length != Object.getOwnPropertyNames(b).length) {
return false;
}
for (const key of keys) {
let av = a[key];
let bv = b[key];
if (root && key === "nodes") {
// Nodes need to be sorted as the order changes when selecting nodes
av = [...av].sort((a, b) => a.id - b.id);
bv = [...bv].sort((a, b) => a.id - b.id);
}
if (!graphEqual(av, bv, false)) {
return false;
}
}
return true;
}
return false;
}
const undoRedo = async (e) => {
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
const prevState = redo.pop();
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
} else if (e.key === "z") {
const prevState = undo.pop();
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
}
}
};
const bindInput = (activeEl) => {
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
for (const evt of ["change", "input", "blur"]) {
if (`on${evt}` in activeEl) {
const listener = () => {
checkState();
activeEl.removeEventListener(evt, listener);
};
activeEl.addEventListener(evt, listener);
return true;
}
}
}
};
window.addEventListener(
"keydown",
(e) => {
requestAnimationFrame(async () => {
const activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
// Check if this is a ctrl+z ctrl+y
if (await undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done
if (bindInput(activeEl)) return;
checkState();
});
},
true
);
// Handle clicking DOM elements (e.g. widgets)
window.addEventListener("mouseup", () => {
checkState();
});
// Handle litegraph clicks
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
LGraphCanvas.prototype.processMouseUp = function (e) {
const v = processMouseUp.apply(this, arguments);
checkState();
return v;
};
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
LGraphCanvas.prototype.processMouseDown = function (e) {
const v = processMouseDown.apply(this, arguments);
checkState();
return v;
};

View File

@ -1,4 +1,4 @@
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget";
@ -121,6 +121,110 @@ function isValidCombo(combo, obj) {
return true;
}
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
}
if (config1[0] instanceof Array) {
if (!isValidCombo(config1[0], config2[0])) return false;
} else if (config1[0] !== config2[0]) {
// Types dont match
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return false;
}
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
let customConfig;
const getCustomConfig = () => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
};
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
if (v1 === v2 || (!v1 && !v2)) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]?.["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]?.["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
// No current step
step = v2;
} else if (v2 == null) {
// No new step
step = v1;
} else {
if (v1 < v2) {
// Ensure v1 is larger for the mod
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
return false;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return false;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
const widget = recreateWidget?.call(this);
// When deleting a node this can be null
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return { customConfig };
}
app.registerExtension({
name: "Comfy.WidgetInputs",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
@ -308,7 +412,7 @@ app.registerExtension({
this.isVirtualNode = true;
}
applyToGraph() {
applyToGraph(extraLinks = []) {
if (!this.outputs[0].links?.length) return;
function get_links(node) {
@ -325,10 +429,9 @@ app.registerExtension({
return links;
}
let links = get_links(this);
let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
// For each output link copy our value over the original widget value
for (const l of links) {
const linkInfo = app.graph.links[l];
for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];
const widgetName = input.widget.name;
@ -405,7 +508,12 @@ app.registerExtension({
}
if (this.outputs[slot].links?.length) {
return this.#isValidConnection(input);
const valid = this.#isValidConnection(input);
if (valid) {
// On connect of additional outputs, copy our value to their widget
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
}
return valid;
}
}
@ -462,12 +570,16 @@ app.registerExtension({
}
}
if (widget.type === "number" || widget.type === "combo") {
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
let control_value = this.widgets_values?.[1];
if (!control_value) {
control_value = "fixed";
}
addValueControlWidget(this, widget, control_value);
addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
}
// When our value changes, update other widgets to reflect our changes
@ -503,6 +615,7 @@ app.registerExtension({
this.#removeWidgets();
this.#onFirstConnection(true);
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
return this.widgets[0];
}
#mergeWidgetConfig() {
@ -543,108 +656,8 @@ app.registerExtension({
#isValidConnection(input, forceUpdate) {
// Only allow connections where the configs match
const output = this.outputs[0];
const config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
const config2 = input.widget[GET_CONFIG]();
if (config1[0] instanceof Array) {
if (!isValidCombo(config1[0], config2[0])) return false;
} else if (config1[0] !== config2[0]) {
// Types dont match
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return false;
}
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
let customConfig;
const getCustomConfig = () => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
};
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
let v1 = config1[1][k];
let v2 = config2[1][k];
if (v1 === v2 || (!v1 && !v2)) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
// No current step
step = v2;
} else if (v2 == null) {
// No new step
step = v1;
} else {
if (v1 < v2) {
// Ensure v1 is larger for the mod
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
return false;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return false;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
this.#recreateWidget();
const widget = this.widgets[0];
// When deleting a node this can be null
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return true;
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
}
#removeWidgets() {

View File

@ -2533,7 +2533,7 @@
var w = this.widgets[i];
if(!w)
continue;
if(w.options && w.options.property && this.properties[ w.options.property ])
if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined))
w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) );
}
if (info.widgets_values) {
@ -5714,10 +5714,10 @@ LGraphNode.prototype.executeAction = function(action)
* @method enableWebGL
**/
LGraphCanvas.prototype.enableWebGL = function() {
if (typeof GL === undefined) {
if (typeof GL === "undefined") {
throw "litegl.js must be included to use a WebGL canvas";
}
if (typeof enableWebGLCanvas === undefined) {
if (typeof enableWebGLCanvas === "undefined") {
throw "webglCanvas.js must be included to use this feature";
}
@ -7110,15 +7110,16 @@ LGraphNode.prototype.executeAction = function(action)
}
};
LGraphCanvas.prototype.copyToClipboard = function() {
LGraphCanvas.prototype.copyToClipboard = function(nodes) {
var clipboard_info = {
nodes: [],
links: []
};
var index = 0;
var selected_nodes_array = [];
for (var i in this.selected_nodes) {
var node = this.selected_nodes[i];
if (!nodes) nodes = this.selected_nodes;
for (var i in nodes) {
var node = nodes[i];
if (node.clonable === false)
continue;
node._relative_id = index;
@ -11702,7 +11703,7 @@ LGraphNode.prototype.executeAction = function(action)
default:
iS = 0; // try with first if no name set
}
if (typeof options.node_from.outputs[iS] !== undefined){
if (typeof options.node_from.outputs[iS] !== "undefined"){
if (iS!==false && iS>-1){
options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type );
}
@ -11730,7 +11731,7 @@ LGraphNode.prototype.executeAction = function(action)
default:
iS = 0; // try with first if no name set
}
if (typeof options.node_to.inputs[iS] !== undefined){
if (typeof options.node_to.inputs[iS] !== "undefined"){
if (iS!==false && iS>-1){
// try connection
options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type);

View File

@ -254,9 +254,9 @@ class ComfyApi extends EventTarget {
* Gets the prompt execution history
* @returns Prompt history including node outputs
*/
async getHistory() {
async getHistory(max_items=200) {
try {
const res = await this.fetchApi("/history");
const res = await this.fetchApi(`/history?max_items=${max_items}`);
return { History: Object.values(await res.json()) };
} catch (error) {
console.error(error);

View File

@ -4,7 +4,10 @@ import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
import { addDomClippingSetting } from "./domWidget.js";
import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js"
export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"
function sanitizeNodeName(string) {
let entityMap = {
@ -409,7 +412,9 @@ export class ComfyApp {
return shiftY;
}
node.prototype.setSizeForImage = function () {
node.prototype.setSizeForImage = function (force) {
if(!force && this.animatedImages) return;
if (this.inputHeight) {
this.setSize(this.size);
return;
@ -426,13 +431,20 @@ export class ComfyApp {
let imagesChanged = false
const output = app.nodeOutputs[this.id + ""];
if (output && output.images) {
if (output?.images) {
this.animatedImages = output?.animated?.find(Boolean);
if (this.images !== output.images) {
this.images = output.images;
imagesChanged = true;
imgURLs = imgURLs.concat(output.images.map(params => {
return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam() + app.getRandParam());
}))
imgURLs = imgURLs.concat(
output.images.map((params) => {
return api.apiURL(
"/view?" +
new URLSearchParams(params).toString() +
(this.animatedImages ? "" : app.getPreviewFormatParam()) + app.getRandParam()
);
})
);
}
}
@ -511,7 +523,35 @@ export class ComfyApp {
return true;
}
if (this.imgs && this.imgs.length) {
if (this.imgs?.length) {
const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET);
if(this.animatedImages) {
// Instead of using the canvas we'll use a IMG
if(widgetIdx > -1) {
// Replace content
const widget = this.widgets[widgetIdx];
widget.options.host.updateImages(this.imgs);
} else {
const host = createImageHost(this);
this.setSizeForImage(true);
const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, {
host,
getHeight: host.getHeight,
onDraw: host.onDraw,
hideOnZoom: false
});
widget.serializeValue = () => undefined;
widget.options.host.updateImages(this.imgs);
}
return;
}
if (widgetIdx > -1) {
this.widgets[widgetIdx].onRemove?.();
this.widgets.splice(widgetIdx, 1);
}
const canvas = app.graph.list_of_graphcanvas[0];
const mouse = canvas.graph_mouse;
if (!canvas.pointer_is_down && this.pointerDown) {
@ -551,31 +591,7 @@ export class ComfyApp {
}
else {
cell_padding = 0;
let best = 0;
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
// compact style
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
}
}
({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh));
}
let anyHovered = false;
@ -767,7 +783,7 @@ export class ComfyApp {
* Adds a handler on paste that extracts and loads images or workflows from pasted JSON data
*/
#addPasteHandler() {
document.addEventListener("paste", (e) => {
document.addEventListener("paste", async (e) => {
// ctrl+shift+v is used to paste nodes with connections
// this is handled by litegraph
if(this.shiftDown) return;
@ -815,7 +831,7 @@ export class ComfyApp {
}
if (workflow && workflow.version && workflow.nodes && workflow.extra) {
this.loadGraphData(workflow);
await this.loadGraphData(workflow);
}
else {
if (e.target.type === "text" || e.target.type === "textarea") {
@ -1165,7 +1181,19 @@ export class ComfyApp {
});
api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output;
const output = this.nodeOutputs[detail.node];
if (detail.merge && output) {
for (const k in detail.output ?? {}) {
const v = output[k];
if (v instanceof Array) {
output[k] = v.concat(detail.output[k]);
} else {
output[k] = detail.output[k];
}
}
} else {
this.nodeOutputs[detail.node] = detail.output;
}
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
@ -1276,9 +1304,11 @@ export class ComfyApp {
canvasEl.tabIndex = "1";
document.body.prepend(canvasEl);
addDomClippingSetting();
this.#addProcessMouseHandler();
this.#addProcessKeyHandler();
this.#addConfigureHandler();
this.#addApiUpdateHandlers();
this.graph = new LGraph();
@ -1315,7 +1345,7 @@ export class ComfyApp {
const json = localStorage.getItem("workflow");
if (json) {
const workflow = JSON.parse(json);
this.loadGraphData(workflow);
await this.loadGraphData(workflow);
restored = true;
}
} catch (err) {
@ -1324,7 +1354,7 @@ export class ComfyApp {
// We failed to restore a workflow so load the default
if (!restored) {
this.loadGraphData();
await this.loadGraphData();
}
// Save current workflow automatically
@ -1332,7 +1362,6 @@ export class ComfyApp {
this.#addDrawNodeHandler();
this.#addDrawGroupsHandler();
this.#addApiUpdateHandlers();
this.#addDropHandler();
this.#addCopyHandler();
this.#addPasteHandler();
@ -1352,11 +1381,95 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("registerCustomNodes");
}
getWidgetType(inputData, inputName) {
const type = inputData[0];
if (Array.isArray(type)) {
return "COMBO";
} else if (`${type}:${inputName}` in this.widgets) {
return `${type}:${inputName}`;
} else if (type in this.widgets) {
return type;
} else {
return null;
}
}
async registerNodeDef(nodeId, nodeData) {
const self = this;
const node = Object.assign(
function ComfyNode() {
var inputs = nodeData["input"]["required"];
if (nodeData["input"]["optional"] != undefined) {
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]);
}
const config = { minWidth: 1, minHeight: 1 };
for (const inputName in inputs) {
const inputData = inputs[inputName];
const type = inputData[0];
let widgetCreated = true;
const widgetType = self.getWidgetType(inputData, inputName);
if(widgetType) {
if(widgetType === "COMBO") {
Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {});
} else {
Object.assign(config, self.widgets[widgetType](this, inputName, inputData, app) || {});
}
} else {
// Node connection inputs
this.addInput(inputName, type);
widgetCreated = false;
}
if(widgetCreated && inputData[1]?.forceInput && config?.widget) {
if (!config.widget.options) config.widget.options = {};
config.widget.options.forceInput = inputData[1].forceInput;
}
if(widgetCreated && inputData[1]?.defaultInput && config?.widget) {
if (!config.widget.options) config.widget.options = {};
config.widget.options.defaultInput = inputData[1].defaultInput;
}
}
for (const o in nodeData["output"]) {
let output = nodeData["output"][o];
if(output instanceof Array) output = "COMBO";
const outputName = nodeData["output_name"][o] || output;
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
this.addOutput(outputName, output, { shape: outputShape });
}
const s = this.computeSize();
s[0] = Math.max(config.minWidth, s[0] * 1.5);
s[1] = Math.max(config.minHeight, s[1]);
this.size = s;
this.serialize_widgets = true;
app.#invokeExtensionsAsync("nodeCreated", this);
},
{
title: nodeData.display_name || nodeData.name,
comfyClass: nodeData.name,
nodeData
}
);
node.prototype.comfyClass = nodeData.name;
this.#addNodeContextMenuHandler(node);
this.#addDrawBackgroundHandler(node, app);
this.#addNodeKeyHandler(node);
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category;
}
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
// Generate list of known widgets
const widgets = Object.assign(
this.widgets = Object.assign(
{},
ComfyWidgets,
...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean)
@ -1364,75 +1477,7 @@ export class ComfyApp {
// Register a node for each definition
for (const nodeId in defs) {
const nodeData = defs[nodeId];
const node = Object.assign(
function ComfyNode() {
var inputs = nodeData["input"]["required"];
if (nodeData["input"]["optional"] != undefined){
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
}
const config = { minWidth: 1, minHeight: 1 };
for (const inputName in inputs) {
const inputData = inputs[inputName];
const type = inputData[0];
let widgetCreated = true;
if (Array.isArray(type)) {
// Enums
Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {});
} else if (`${type}:${inputName}` in widgets) {
// Support custom widgets by Type:Name
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
} else if (type in widgets) {
// Standard type widgets
Object.assign(config, widgets[type](this, inputName, inputData, app) || {});
} else {
// Node connection inputs
this.addInput(inputName, type);
widgetCreated = false;
}
if(widgetCreated && inputData[1]?.forceInput && config?.widget) {
if (!config.widget.options) config.widget.options = {};
config.widget.options.forceInput = inputData[1].forceInput;
}
if(widgetCreated && inputData[1]?.defaultInput && config?.widget) {
if (!config.widget.options) config.widget.options = {};
config.widget.options.defaultInput = inputData[1].defaultInput;
}
}
for (const o in nodeData["output"]) {
let output = nodeData["output"][o];
if(output instanceof Array) output = "COMBO";
const outputName = nodeData["output_name"][o] || output;
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
this.addOutput(outputName, output, { shape: outputShape });
}
const s = this.computeSize();
s[0] = Math.max(config.minWidth, s[0] * 1.5);
s[1] = Math.max(config.minHeight, s[1]);
this.size = s;
this.serialize_widgets = true;
app.#invokeExtensionsAsync("nodeCreated", this);
},
{
title: nodeData.display_name || nodeData.name,
comfyClass: nodeData.name,
nodeData
}
);
node.prototype.comfyClass = nodeData.name;
this.#addNodeContextMenuHandler(node);
this.#addDrawBackgroundHandler(node, app);
this.#addNodeKeyHandler(node);
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category;
this.registerNodeDef(nodeId, defs[nodeId]);
}
}
@ -1475,9 +1520,14 @@ export class ComfyApp {
showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
this.ui.dialog.show(
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
(t) => `<li>${t}</li>`
).join("")}</ul>${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}`
$el("div", [
$el("span", { textContent: "When loading the graph, the following node types were not found: " }),
$el(
"ul",
Array.from(new Set(missingNodeTypes)).map((t) => $el("li", { textContent: t }))
),
...(hasAddedNodes ? [$el("span", { textContent: "Nodes that have failed to load will show as red on the graph." })] : []),
])
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
@ -1488,31 +1538,35 @@ export class ComfyApp {
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
*/
loadGraphData(graphData) {
async loadGraphData(graphData) {
this.clean();
let reset_invalid_values = false;
if (!graphData) {
if (typeof structuredClone === "undefined")
{
graphData = JSON.parse(JSON.stringify(defaultGraph));
}else
{
graphData = structuredClone(defaultGraph);
}
graphData = defaultGraph;
reset_invalid_values = true;
}
if (typeof structuredClone === "undefined")
{
graphData = JSON.parse(JSON.stringify(graphData));
}else
{
graphData = structuredClone(graphData);
}
const missingNodeTypes = [];
await this.#invokeExtensionsAsync("beforeConfigureGraph", graphData, missingNodeTypes);
for (let n of graphData.nodes) {
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix
if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix
// Find missing node types
if (!(n.type in LiteGraph.registered_node_types)) {
n.type = sanitizeNodeName(n.type);
missingNodeTypes.push(n.type);
n.type = sanitizeNodeName(n.type);
}
}
@ -1604,6 +1658,7 @@ export class ComfyApp {
if (missingNodeTypes.length) {
this.showMissingNodesError(missingNodeTypes);
}
await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes);
}
/**
@ -1611,92 +1666,98 @@ export class ComfyApp {
* @returns The workflow and node links
*/
async graphToPrompt() {
for (const node of this.graph.computeExecutionOrder(false)) {
if (node.isVirtualNode) {
// Don't serialize frontend only nodes but let them make changes
if (node.applyToGraph) {
node.applyToGraph();
for (const outerNode of this.graph.computeExecutionOrder(false)) {
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) {
if (node.isVirtualNode) {
// Don't serialize frontend only nodes but let them make changes
if (node.applyToGraph) {
node.applyToGraph();
}
}
continue;
}
}
const workflow = this.graph.serialize();
const output = {};
// Process nodes in order of execution
for (const node of this.graph.computeExecutionOrder(false)) {
const n = workflow.nodes.find((n) => n.id === node.id);
for (const outerNode of this.graph.computeExecutionOrder(false)) {
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) {
if (node.isVirtualNode) {
continue;
}
if (node.isVirtualNode) {
continue;
}
if (node.mode === 2 || node.mode === 4) {
// Don't serialize muted nodes
continue;
}
if (node.mode === 2 || node.mode === 4) {
// Don't serialize muted nodes
continue;
}
const inputs = {};
const widgets = node.widgets;
const inputs = {};
const widgets = node.widgets;
// Store all widget values
if (widgets) {
for (const i in widgets) {
const widget = widgets[i];
if (!widget.options || widget.options.serialize !== false) {
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
// Store all widget values
if (widgets) {
for (const i in widgets) {
const widget = widgets[i];
if (!widget.options || widget.options.serialize !== false) {
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(node, i) : widget.value;
}
}
}
}
// Store all node links
for (let i in node.inputs) {
let parent = node.getInputNode(i);
if (parent) {
let link = node.getInputLink(i);
while (parent.mode === 4 || parent.isVirtualNode) {
let found = false;
if (parent.isVirtualNode) {
link = parent.getInputLink(link.origin_slot);
if (link) {
parent = parent.getInputNode(link.target_slot);
if (parent) {
found = true;
}
}
} else if (link && parent.mode === 4) {
let all_inputs = [link.origin_slot];
if (parent.inputs) {
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
link = parent.getInputLink(parent_input);
if (link) {
parent = parent.getInputNode(parent_input);
}
// Store all node links
for (let i in node.inputs) {
let parent = node.getInputNode(i);
if (parent) {
let link = node.getInputLink(i);
while (parent.mode === 4 || parent.isVirtualNode) {
let found = false;
if (parent.isVirtualNode) {
link = parent.getInputLink(link.origin_slot);
if (link) {
parent = parent.getInputNode(link.target_slot);
if (parent) {
found = true;
break;
}
}
} else if (link && parent.mode === 4) {
let all_inputs = [link.origin_slot];
if (parent.inputs) {
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
link = parent.getInputLink(parent_input);
if (link) {
parent = parent.getInputNode(parent_input);
}
found = true;
break;
}
}
}
}
if (!found) {
break;
}
}
if (!found) {
break;
if (link) {
if (parent?.updateLink) {
link = parent.updateLink(link);
}
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
}
}
if (link) {
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
}
}
}
output[String(node.id)] = {
inputs,
class_type: node.comfyClass,
};
output[String(node.id)] = {
inputs,
class_type: node.comfyClass,
};
}
}
// Remove inputs connected to removed nodes
@ -1816,7 +1877,7 @@ export class ComfyApp {
const pngInfo = await getPngMetadata(file);
if (pngInfo) {
if (pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow));
await this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.parameters) {
importA1111(this.graph, pngInfo.parameters);
}
@ -1832,21 +1893,21 @@ export class ComfyApp {
}
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
reader.onload = () => {
reader.onload = async () => {
const jsonContent = JSON.parse(reader.result);
if (jsonContent?.templates) {
this.loadTemplateData(jsonContent);
} else if(this.isApiJson(jsonContent)) {
this.loadApiJson(jsonContent);
} else {
this.loadGraphData(jsonContent);
await this.loadGraphData(jsonContent);
}
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));
await this.loadGraphData(JSON.parse(info.workflow));
}
}
}
@ -1867,7 +1928,7 @@ export class ComfyApp {
for (const id of ids) {
const data = apiData[id];
const node = LiteGraph.createNode(data.class_type);
node.id = id;
node.id = isNaN(+id) ? id : +id;
graph.add(node);
}

322
web/scripts/domWidget.js Normal file
View File

@ -0,0 +1,322 @@
import { app, ANIM_PREVIEW_WIDGET } from "./app.js";
const SIZE = Symbol();
function intersect(a, b) {
const x = Math.max(a.x, b.x);
const num1 = Math.min(a.x + a.width, b.x + b.width);
const y = Math.max(a.y, b.y);
const num2 = Math.min(a.y + a.height, b.y + b.height);
if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y];
else return null;
}
function getClipPath(node, element, elRect) {
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
if (selectedNode && selectedNode !== node) {
const MARGIN = 7;
const scale = app.canvas.ds.scale;
const bounding = selectedNode.getBounding();
const intersection = intersect(
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
{
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
width: bounding[2] + MARGIN + MARGIN,
height: bounding[3] + MARGIN + MARGIN,
}
);
if (!intersection) {
return "";
}
const widgetRect = element.getBoundingClientRect();
const clipX = intersection[0] - widgetRect.x / scale + "px";
const clipY = intersection[1] - widgetRect.y / scale + "px";
const clipWidth = intersection[2] + "px";
const clipHeight = intersection[3] + "px";
const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`;
return path;
}
return "";
}
function computeSize(size) {
if (this.widgets?.[0]?.last_y == null) return;
let y = this.widgets[0].last_y;
let freeSpace = size[1] - y;
let widgetHeight = 0;
let dom = [];
for (const w of this.widgets) {
if (w.type === "converted-widget") {
// Ignore
delete w.computedHeight;
} else if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else if (w.element) {
// Extract DOM widget size info
const styles = getComputedStyle(w.element);
let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height"));
let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height"));
let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height");
if (prefHeight.endsWith?.("%")) {
prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100);
} else {
prefHeight = parseInt(prefHeight);
if (isNaN(minHeight)) {
minHeight = prefHeight;
}
}
if (isNaN(minHeight)) {
minHeight = 50;
}
if (!isNaN(maxHeight)) {
if (!isNaN(prefHeight)) {
prefHeight = Math.min(prefHeight, maxHeight);
} else {
prefHeight = maxHeight;
}
}
dom.push({
minHeight,
prefHeight,
w,
});
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
freeSpace -= widgetHeight;
// Calculate sizes with all widgets at their min height
const prefGrow = []; // Nodes that want to grow to their prefd size
const canGrow = []; // Nodes that can grow to auto size
let growBy = 0;
for (const d of dom) {
freeSpace -= d.minHeight;
if (isNaN(d.prefHeight)) {
canGrow.push(d);
d.w.computedHeight = d.minHeight;
} else {
const diff = d.prefHeight - d.minHeight;
if (diff > 0) {
prefGrow.push(d);
growBy += diff;
d.diff = diff;
} else {
d.w.computedHeight = d.minHeight;
}
}
}
if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) {
// Allocate space for image
freeSpace -= 220;
}
if (freeSpace < 0) {
// Not enough space for all widgets so we need to grow
size[1] -= freeSpace;
this.graph.setDirtyCanvas(true);
} else {
// Share the space between each
const growDiff = freeSpace - growBy;
if (growDiff > 0) {
// All pref sizes can be fulfilled
freeSpace = growDiff;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight;
}
} else {
// We need to grow evenly
const shared = -growDiff / prefGrow.length;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight - shared;
}
freeSpace = 0;
}
if (freeSpace > 0 && canGrow.length) {
// Grow any that are auto height
const shared = freeSpace / canGrow.length;
for (const d of canGrow) {
d.w.computedHeight += shared;
}
}
}
// Position each of the widgets
for (const w of this.widgets) {
w.y = y;
if (w.computedHeight) {
y += w.computedHeight;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen
const elementWidgets = new Set();
const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
LGraphCanvas.prototype.computeVisibleNodes = function () {
const visibleNodes = computeVisibleNodes.apply(this, arguments);
for (const node of app.graph._nodes) {
if (elementWidgets.has(node)) {
const hidden = visibleNodes.indexOf(node) === -1;
for (const w of node.widgets) {
if (w.element) {
w.element.hidden = hidden;
if (hidden) {
w.options.onHide?.(w);
}
}
}
}
}
return visibleNodes;
};
let enableDomClipping = true;
export function addDomClippingSetting() {
app.ui.settings.addSetting({
id: "Comfy.DOMClippingEnabled",
name: "Enable DOM element clipping (enabling may reduce performance)",
type: "boolean",
defaultValue: enableDomClipping,
onChange(value) {
enableDomClipping = !!value;
},
});
}
LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options };
if (!element.parentElement) {
document.body.append(element);
}
let mouseDownHandler;
if (element.blur) {
mouseDownHandler = (event) => {
if (!element.contains(event.target)) {
element.blur();
}
};
document.addEventListener("mousedown", mouseDownHandler);
}
const widget = {
type,
name,
get value() {
return options.getValue?.() ?? undefined;
},
set value(v) {
options.setValue?.(v);
widget.callback?.(widget.value);
},
draw: function (ctx, node, widgetWidth, y, widgetHeight) {
if (widget.computedHeight == null) {
computeSize.call(node, node.size);
}
const hidden =
node.flags?.collapsed ||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
widget.computedHeight <= 0 ||
widget.type === "converted-widget";
element.hidden = hidden;
element.style.display = hidden ? "none" : null;
if (hidden) {
widget.options.onHide?.(widget);
return;
}
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d);
Object.assign(element.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - margin * 2}px`,
height: `${(widget.computedHeight ?? 50) - margin * 2}px`,
position: "absolute",
zIndex: app.graph._nodes.indexOf(node),
});
if (enableDomClipping) {
element.style.clipPath = getClipPath(node, element, elRect);
element.style.willChange = "clip-path";
}
this.options.onDraw?.(widget);
},
element,
options,
onRemove() {
if (mouseDownHandler) {
document.removeEventListener("mousedown", mouseDownHandler);
}
element.remove();
},
};
for (const evt of options.selectOn) {
element.addEventListener(evt, () => {
app.canvas.selectNode(this);
app.canvas.bringToFront(this);
});
}
this.addCustomWidget(widget);
elementWidgets.add(this);
const collapse = this.collapse;
this.collapse = function() {
collapse.apply(this, arguments);
if(this.flags?.collapsed) {
element.hidden = true;
element.style.display = "none";
}
}
const onRemoved = this.onRemoved;
this.onRemoved = function () {
element.remove();
elementWidgets.delete(this);
onRemoved?.apply(this, arguments);
};
if (!this[SIZE]) {
this[SIZE] = true;
const onResize = this.onResize;
this.onResize = function (size) {
options.beforeResize?.call(widget, this);
computeSize.call(this, size);
onResize?.apply(this, arguments);
options.afterResize?.call(widget, this);
};
}
return widget;
};

View File

@ -24,7 +24,7 @@ export function getPngMetadata(file) {
const length = dataView.getUint32(offset);
// Get the chunk type
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
if (type === "tEXt") {
if (type === "tEXt" || type == "comf") {
// Get the keyword
let keyword_end = offset + 8;
while (pngData[keyword_end] !== 0) {
@ -50,7 +50,6 @@ export function getPngMetadata(file) {
function parseExifData(exifData) {
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
console.log(exifData);
// Function to read 16-bit and 32-bit integers from binary data
function readInt(offset, isLittleEndian, length) {
@ -126,6 +125,9 @@ export function getWebpMetadata(file) {
const chunk_length = dataView.getUint32(offset + 4, true);
const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
if (chunk_type === "EXIF") {
if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") {
offset += 6;
}
let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
for (var key in data) {
var value = data[key];

View File

@ -462,8 +462,8 @@ class ComfyList {
return $el("div", {textContent: item.prompt[0] + ": "}, [
$el("button", {
textContent: "Load",
onclick: () => {
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
onclick: async () => {
await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
if (item.outputs) {
app.nodeOutputs = item.outputs;
}
@ -599,7 +599,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png,.latent,.safetensors",
accept: ".json,image/png,.latent,.safetensors,image/webp",
style: {display: "none"},
parent: document.body,
onchange: () => {
@ -784,9 +784,9 @@ export class ComfyUI {
}
}),
$el("button", {
id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
await app.loadGraphData()
}
}
}),

View File

@ -0,0 +1,97 @@
import { $el } from "../ui.js";
export function calculateImageGrid(imgs, dw, dh) {
let best = 0;
let w = imgs[0].naturalWidth;
let h = imgs[0].naturalHeight;
const numImages = imgs.length;
let cellWidth, cellHeight, cols, rows, shiftX;
// compact style
for (let c = 1; c <= numImages; c++) {
const r = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / r;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
rows = r;
shiftX = c * ((cW - imageW) / 2);
}
}
return { cellWidth, cellHeight, cols, rows, shiftX };
}
export function createImageHost(node) {
const el = $el("div.comfy-img-preview");
let currentImgs;
let first = true;
function updateSize() {
let w = null;
let h = null;
if (currentImgs) {
let elH = el.clientHeight;
if (first) {
first = false;
// On first run, if we are small then grow a bit
if (elH < 190) {
elH = 190;
}
el.style.setProperty("--comfy-widget-min-height", elH);
} else {
el.style.setProperty("--comfy-widget-min-height", null);
}
const nw = node.size[0];
({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH));
w += "px";
h += "px";
el.style.setProperty("--comfy-img-preview-width", w);
el.style.setProperty("--comfy-img-preview-height", h);
}
}
return {
el,
updateImages(imgs) {
if (imgs !== currentImgs) {
if (currentImgs == null) {
requestAnimationFrame(() => {
updateSize();
});
}
el.replaceChildren(...imgs);
currentImgs = imgs;
node.onResize(node.size);
node.graph.setDirtyCanvas(true, true);
}
},
getHeight() {
updateSize();
},
onDraw() {
// Element from point uses a hittest find elements so we need to toggle pointer events
el.style.pointerEvents = "all";
const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]);
el.style.pointerEvents = "none";
if(!over) return;
// Set the overIndex so Open Image etc work
const idx = currentImgs.indexOf(over);
node.overIndex = idx;
},
};
}

View File

@ -1,4 +1,5 @@
import { api } from "./api.js"
import "./domWidget.js";
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
let defaultVal = inputData[1]["default"];
@ -22,18 +23,89 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
}
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) {
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt.
});
valueControl.afterQueued = () => {
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) {
let name = inputData[1]?.control_after_generate;
if(typeof name !== "string") {
name = widgetName;
}
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, {
addFilterList: false,
controlAfterGenerateName: name
}, inputData);
return widgets[0];
}
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", options, inputData) {
if (!defaultValue) defaultValue = "randomize";
if (!options) options = {};
const getName = (defaultName, optionName) => {
let name = defaultName;
if (options[optionName]) {
name = options[optionName];
} else if (typeof inputData?.[1]?.[defaultName] === "string") {
name = inputData?.[1]?.[defaultName];
} else if (inputData?.[1]?.control_prefix) {
name = inputData?.[1]?.control_prefix + " " + name
}
return name;
}
const widgets = [];
const valueControl = node.addWidget(
"combo",
getName("control_after_generate", "controlAfterGenerateName"),
defaultValue,
function () {},
{
values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt.
}
);
widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
let comboFilter;
if (isCombo && options.addFilterList !== false) {
comboFilter = node.addWidget(
"string",
getName("control_filter_list", "controlFilterListName"),
"",
function () {},
{
serialize: false, // Don't include this in prompt.
}
);
widgets.push(comboFilter);
}
valueControl.afterQueued = () => {
var v = valueControl.value;
if (targetWidget.type == "combo" && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
let current_length = targetWidget.options.values.length;
if (isCombo && v !== "fixed") {
let values = targetWidget.options.values;
const filter = comboFilter?.value;
if (filter) {
let check;
if (filter.startsWith("/") && filter.endsWith("/")) {
try {
const regex = new RegExp(filter.substring(1, filter.length - 1));
check = (item) => regex.test(item);
} catch (error) {
console.error("Error constructing RegExp filter for node " + node.id, filter, error);
}
}
if (!check) {
const lower = filter.toLocaleLowerCase();
check = (item) => item.toLocaleLowerCase().includes(lower);
}
values = values.filter(item => check(item));
if (!values.length && targetWidget.options.values.length) {
console.warn("Filter for node " + node.id + " has filtered out all items", filter);
}
}
let current_index = values.indexOf(targetWidget.value);
let current_length = values.length;
switch (v) {
case "increment":
@ -50,11 +122,12 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) {
let value = targetWidget.options.values[current_index];
let value = values[current_index];
targetWidget.value = value;
targetWidget.callback(value);
}
} else { //number
} else {
//number
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
@ -77,186 +150,68 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
default:
break;
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min) targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
targetWidget.callback(targetWidget.value);
}
}
return valueControl;
};
return widgets;
};
function seedWidget(node, inputName, inputData, app) {
const seed = ComfyWidgets.INT(node, inputName, inputData, app);
const seedControl = addValueControlWidget(node, seed.widget, "randomize");
function seedWidget(node, inputName, inputData, app, widgetName) {
const seed = createIntWidget(node, inputName, inputData, app, true);
const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName, inputData);
seed.widget.linkedWidgets = [seedControl];
return seed;
}
const MultilineSymbol = Symbol();
const MultilineResizeSymbol = Symbol();
function createIntWidget(node, inputName, inputData, app, isSeedInput) {
const control = inputData[1]?.control_after_generate;
if (!isSeedInput && control) {
return seedWidget(node, inputName, inputData, app, typeof control === "string" ? control : undefined);
}
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
widgetType,
inputName,
val,
function (v) {
const s = this.options.step / 10;
this.value = Math.round(v / s) * s;
},
config
),
};
}
function addMultilineWidget(node, name, opts, app) {
const MIN_SIZE = 50;
const inputEl = document.createElement("textarea");
inputEl.className = "comfy-multiline-input";
inputEl.value = opts.defaultVal;
inputEl.placeholder = opts.placeholder || name;
function computeSize(size) {
if (node.widgets[0].last_y == null) return;
let y = node.widgets[0].last_y;
let freeSpace = size[1] - y;
// Compute the height of all non customtext widgets
let widgetHeight = 0;
const multi = [];
for (let i = 0; i < node.widgets.length; i++) {
const w = node.widgets[i];
if (w.type === "customtext") {
multi.push(w);
} else {
if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// See how large each text input can be
freeSpace -= widgetHeight;
freeSpace /= multi.length + (!!node.imgs?.length);
if (freeSpace < MIN_SIZE) {
// There isnt enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
node.graph.setDirtyCanvas(true);
}
// Position each of the widgets
for (const w of node.widgets) {
w.y = y;
if (w.type === "customtext") {
y += freeSpace;
w.computedHeight = freeSpace - multi.length*4;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
node.inputHeight = freeSpace;
}
const widget = {
type: "customtext",
name,
get value() {
return this.inputEl.value;
const widget = node.addDOMWidget(name, "customtext", inputEl, {
getValue() {
return inputEl.value;
},
set value(x) {
this.inputEl.value = x;
setValue(v) {
inputEl.value = v;
},
draw: function (ctx, _, widgetWidth, y, widgetHeight) {
if (!this.parent.inputHeight) {
// If we are initially offscreen when created we wont have received a resize event
// Calculate it here instead
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d)
Object.assign(this.inputEl.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - (margin * 2)}px`,
height: `${this.parent.inputHeight - (margin * 2)}px`,
position: "absolute",
background: (!node.color)?'':node.color,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
});
this.inputEl.hidden = !visible;
},
};
widget.inputEl = document.createElement("textarea");
widget.inputEl.className = "comfy-multiline-input";
widget.inputEl.value = opts.defaultVal;
widget.inputEl.placeholder = opts.placeholder || "";
document.addEventListener("mousedown", function (event) {
if (!widget.inputEl.contains(event.target)) {
widget.inputEl.blur();
}
});
widget.parent = node;
document.body.appendChild(widget.inputEl);
widget.inputEl = inputEl;
node.addCustomWidget(widget);
app.canvas.onDrawBackground = function () {
// Draw node isnt fired once the node is off the screen
// if it goes off screen quickly, the input may not be removed
// this shifts it off screen so it can be moved back if the node is visible.
for (let n in app.graph._nodes) {
n = graph._nodes[n];
for (let w in n.widgets) {
let wid = n.widgets[w];
if (Object.hasOwn(wid, "inputEl")) {
wid.inputEl.style.left = -8000 + "px";
wid.inputEl.style.position = "absolute";
}
}
}
};
node.onRemoved = function () {
// When removing this node we need to remove the input from the DOM
for (let y in this.widgets) {
if (this.widgets[y].inputEl) {
this.widgets[y].inputEl.remove();
}
}
};
widget.onRemove = () => {
widget.inputEl?.remove();
// Restore original size handler if we are the last
if (!--node[MultilineSymbol]) {
node.onResize = node[MultilineResizeSymbol];
delete node[MultilineSymbol];
delete node[MultilineResizeSymbol];
}
};
if (node[MultilineSymbol]) {
node[MultilineSymbol]++;
} else {
node[MultilineSymbol] = 1;
const onResize = (node[MultilineResizeSymbol] = node.onResize);
node.onResize = function (size) {
computeSize(size);
// Call original resizer handler
if (onResize) {
onResize.apply(this, arguments);
}
};
}
inputEl.addEventListener("input", () => {
widget.callback?.(widget.value);
});
return { minWidth: 400, minHeight: 200, widget };
}
@ -288,31 +243,26 @@ export const ComfyWidgets = {
}, config) };
},
INT(node, inputName, inputData, app) {
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
widgetType,
inputName,
val,
function (v) {
const s = this.options.step / 10;
this.value = Math.round(v / s) * s;
},
config
),
};
return createIntWidget(node, inputName, inputData, app);
},
BOOLEAN(node, inputName, inputData) {
let defaultVal = inputData[1]["default"];
let defaultVal = false;
let options = {};
if (inputData[1]) {
if (inputData[1].default)
defaultVal = inputData[1].default;
if (inputData[1].label_on)
options["on"] = inputData[1].label_on;
if (inputData[1].label_off)
options["off"] = inputData[1].label_off;
}
return {
widget: node.addWidget(
"toggle",
inputName,
defaultVal,
() => {},
{"on": inputData[1].label_on, "off": inputData[1].label_off}
options,
)
};
},
@ -338,10 +288,14 @@ export const ComfyWidgets = {
if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default;
}
return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
if (inputData[1]?.control_after_generate) {
res.widget.linkedWidgets = addValueControlWidgets(node, res.widget, undefined, undefined, inputData);
}
return res;
},
IMAGEUPLOAD(node, inputName, inputData, app) {
const imageWidget = node.widgets.find((w) => w.name === "image");
const imageWidget = node.widgets.find((w) => w.name === (inputData[1]?.widget ?? "image"));
let uploadWidget;
function showImage(name) {
@ -455,9 +409,10 @@ export const ComfyWidgets = {
document.body.append(fileInput);
// Create the button widget for selecting the files
uploadWidget = node.addWidget("button", "choose file to upload", "image", () => {
uploadWidget = node.addWidget("button", inputName, "image", () => {
fileInput.click();
});
uploadWidget.label = "choose file to upload";
uploadWidget.serialize = false;
// Add handler to check if an image is being dragged over our node

View File

@ -409,6 +409,21 @@ dialog::backdrop {
width: calc(100% - 10px);
}
.comfy-img-preview {
pointer-events: none;
overflow: hidden;
display: flex;
flex-wrap: wrap;
align-content: flex-start;
justify-content: center;
}
.comfy-img-preview img {
object-fit: contain;
width: var(--comfy-img-preview-width);
height: var(--comfy-img-preview-height);
}
/* Search box */
.litegraph.litesearchbox {