mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Merge branch 'master' into image-cache
This commit is contained in:
commit
c92f3dca73
9
.vscode/settings.json
vendored
Normal file
9
.vscode/settings.json
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"path-intellisense.mappings": {
|
||||
"../": "${workspaceFolder}/web/extensions/core"
|
||||
},
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.autopep8"
|
||||
},
|
||||
"python.formatting.provider": "none"
|
||||
}
|
@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
- Fully supports SD1.x, SD2.x and SDXL
|
||||
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||
- Asynchronous Queue system
|
||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
|
||||
@ -30,6 +30,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
|
||||
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
|
||||
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
|
||||
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
|
||||
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
|
||||
- Starts up very fast.
|
||||
- Works fully offline: will never download anything.
|
||||
@ -43,6 +45,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
| Ctrl + A | Select all nodes |
|
||||
@ -98,6 +101,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```
|
||||
|
||||
This is the command to install the nightly with ROCm 5.7 that might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```
|
||||
|
||||
### NVIDIA
|
||||
@ -190,7 +194,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
|
||||
|
||||
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
|
||||
|
||||
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers this option does not do anything.
|
||||
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
|
||||
|
||||
```--dont-upcast-attention```
|
||||
|
||||
|
@ -54,6 +54,7 @@ class ControlNet(nn.Module):
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=comfy.ops,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||
|
@ -62,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
|
||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
||||
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
|
||||
|
||||
fpte_group = parser.add_mutually_exclusive_group()
|
||||
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
|
||||
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
|
||||
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
|
||||
|
@ -33,7 +33,7 @@ class ControlBase:
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
self.timestep_percent_range = (1.0, 0.0)
|
||||
self.timestep_percent_range = (0.0, 1.0)
|
||||
self.timestep_range = None
|
||||
|
||||
if device is None:
|
||||
@ -42,7 +42,7 @@ class ControlBase:
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = False
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
||||
self.cond_hint_original = cond_hint
|
||||
self.strength = strength
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
|
@ -858,7 +858,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
||||
return (input - model(input, sigma_in, **kwargs)) / sigma
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
timesteps = sigmas.clone()
|
||||
if sigmas[-1] == 0:
|
||||
timesteps = sigmas[:]
|
||||
|
@ -750,3 +750,61 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
||||
if sigmas[i + 1] > 0:
|
||||
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
s_end = sigmas[-1]
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == s_end:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
elif sigmas[i + 2] == s_end:
|
||||
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
|
||||
w = 2 * sigmas[0]
|
||||
w2 = sigmas[i+1]/w
|
||||
w1 = 1 - w2
|
||||
|
||||
d_prime = d * w1 + d_2 * w2
|
||||
|
||||
|
||||
x = x + d_prime * dt
|
||||
|
||||
else:
|
||||
# Heun++
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
dt_2 = sigmas[i + 2] - sigmas[i + 1]
|
||||
|
||||
x_3 = x_2 + d_2 * dt_2
|
||||
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
|
||||
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
|
||||
|
||||
w = 3 * sigmas[0]
|
||||
w2 = sigmas[i + 1] / w
|
||||
w3 = sigmas[i + 2] / w
|
||||
w1 = 1 - w2 - w3
|
||||
|
||||
d_prime = w1 * d + w2 * d_2 + w3 * d_3
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
@ -5,8 +5,10 @@ import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
from functools import partial
|
||||
|
||||
from .diffusionmodules.util import checkpoint
|
||||
|
||||
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
@ -276,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
)
|
||||
return r1
|
||||
|
||||
BROKEN_XFORMERS = False
|
||||
try:
|
||||
x_vers = xformers.__version__
|
||||
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
|
||||
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
if BROKEN_XFORMERS:
|
||||
if b * heads > 65535:
|
||||
return attention_pytorch(q, k, v, heads, mask)
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
@ -370,53 +383,72 @@ class CrossAttention(nn.Module):
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||
disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
|
||||
super().__init__()
|
||||
|
||||
self.ff_in = ff_in or inner_dim is not None
|
||||
if inner_dim is None:
|
||||
inner_dim = dim
|
||||
|
||||
self.is_res = inner_dim == dim
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if disable_temporal_crossattention:
|
||||
if switch_temporal_ca_to_sa:
|
||||
raise ValueError
|
||||
else:
|
||||
self.attn2 = None
|
||||
else:
|
||||
context_dim_attn2 = None
|
||||
if not switch_temporal_ca_to_sa:
|
||||
context_dim_attn2 = context_dim
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.checkpoint = checkpoint
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||
|
||||
def forward(self, x, context=None, transformer_options={}):
|
||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, transformer_options={}):
|
||||
extra_options = {}
|
||||
block = None
|
||||
block_index = 0
|
||||
if "current_index" in transformer_options:
|
||||
extra_options["transformer_index"] = transformer_options["current_index"]
|
||||
if "block_index" in transformer_options:
|
||||
block_index = transformer_options["block_index"]
|
||||
extra_options["block_index"] = block_index
|
||||
if "original_shape" in transformer_options:
|
||||
extra_options["original_shape"] = transformer_options["original_shape"]
|
||||
if "block" in transformer_options:
|
||||
block = transformer_options["block"]
|
||||
extra_options["block"] = block
|
||||
if "cond_or_uncond" in transformer_options:
|
||||
extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
|
||||
if "patches" in transformer_options:
|
||||
transformer_patches = transformer_options["patches"]
|
||||
else:
|
||||
transformer_patches = {}
|
||||
block = transformer_options.get("block", None)
|
||||
block_index = transformer_options.get("block_index", 0)
|
||||
transformer_patches = {}
|
||||
transformer_patches_replace = {}
|
||||
|
||||
for k in transformer_options:
|
||||
if k == "patches":
|
||||
transformer_patches = transformer_options[k]
|
||||
elif k == "patches_replace":
|
||||
transformer_patches_replace = transformer_options[k]
|
||||
else:
|
||||
extra_options[k] = transformer_options[k]
|
||||
|
||||
extra_options["n_heads"] = self.n_heads
|
||||
extra_options["dim_head"] = self.d_head
|
||||
|
||||
if "patches_replace" in transformer_options:
|
||||
transformer_patches_replace = transformer_options["patches_replace"]
|
||||
else:
|
||||
transformer_patches_replace = {}
|
||||
if self.ff_in:
|
||||
x_skip = x
|
||||
x = self.ff_in(self.norm_in(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
n = self.norm1(x)
|
||||
if self.disable_self_attn:
|
||||
@ -465,31 +497,34 @@ class BasicTransformerBlock(nn.Module):
|
||||
for p in patch:
|
||||
x = p(x, extra_options)
|
||||
|
||||
n = self.norm2(x)
|
||||
|
||||
context_attn2 = context
|
||||
value_attn2 = None
|
||||
if "attn2_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_patch"]
|
||||
value_attn2 = context_attn2
|
||||
for p in patch:
|
||||
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||
|
||||
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
|
||||
block_attn2 = transformer_block
|
||||
if block_attn2 not in attn2_replace_patch:
|
||||
block_attn2 = block
|
||||
|
||||
if block_attn2 in attn2_replace_patch:
|
||||
if value_attn2 is None:
|
||||
if self.attn2 is not None:
|
||||
n = self.norm2(x)
|
||||
if self.switch_temporal_ca_to_sa:
|
||||
context_attn2 = n
|
||||
else:
|
||||
context_attn2 = context
|
||||
value_attn2 = None
|
||||
if "attn2_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_patch"]
|
||||
value_attn2 = context_attn2
|
||||
n = self.attn2.to_q(n)
|
||||
context_attn2 = self.attn2.to_k(context_attn2)
|
||||
value_attn2 = self.attn2.to_v(value_attn2)
|
||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||
n = self.attn2.to_out(n)
|
||||
else:
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
for p in patch:
|
||||
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||
|
||||
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
|
||||
block_attn2 = transformer_block
|
||||
if block_attn2 not in attn2_replace_patch:
|
||||
block_attn2 = block
|
||||
|
||||
if block_attn2 in attn2_replace_patch:
|
||||
if value_attn2 is None:
|
||||
value_attn2 = context_attn2
|
||||
n = self.attn2.to_q(n)
|
||||
context_attn2 = self.attn2.to_k(context_attn2)
|
||||
value_attn2 = self.attn2.to_v(value_attn2)
|
||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||
n = self.attn2.to_out(n)
|
||||
else:
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
|
||||
if "attn2_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_output_patch"]
|
||||
@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
n = p(n, extra_options)
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
if self.is_res:
|
||||
x_skip = x
|
||||
x = self.ff(self.norm3(x))
|
||||
if self.is_res:
|
||||
x += x_skip
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module):
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
class SpatialVideoTransformer(SpatialTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
use_linear=False,
|
||||
context_dim=None,
|
||||
use_spatial_context=False,
|
||||
timesteps=None,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
time_context_dim=None,
|
||||
ff_in=False,
|
||||
checkpoint=False,
|
||||
time_depth=1,
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
max_time_embed_period: int = 10000,
|
||||
dtype=None, device=None, operations=comfy.ops
|
||||
):
|
||||
super().__init__(
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=depth,
|
||||
dropout=dropout,
|
||||
use_checkpoint=checkpoint,
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
disable_self_attn=disable_self_attn,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.time_depth = time_depth
|
||||
self.depth = depth
|
||||
self.max_time_embed_period = max_time_embed_period
|
||||
|
||||
time_mix_d_head = d_head
|
||||
n_time_mix_heads = n_heads
|
||||
|
||||
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
|
||||
|
||||
inner_dim = n_heads * d_head
|
||||
if use_spatial_context:
|
||||
time_context_dim = context_dim
|
||||
|
||||
self.time_stack = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_time_mix_heads,
|
||||
time_mix_d_head,
|
||||
dropout=dropout,
|
||||
context_dim=time_context_dim,
|
||||
# timesteps=timesteps,
|
||||
checkpoint=checkpoint,
|
||||
ff_in=ff_in,
|
||||
inner_dim=time_mix_inner_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
)
|
||||
|
||||
assert len(self.time_stack) == len(self.transformer_blocks)
|
||||
|
||||
self.use_spatial_context = use_spatial_context
|
||||
self.in_channels = in_channels
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.time_pos_embed = nn.Sequential(
|
||||
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor, merge_strategy=merge_strategy
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
time_context: Optional[torch.Tensor] = None,
|
||||
timesteps: Optional[int] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
transformer_options={}
|
||||
) -> torch.Tensor:
|
||||
_, _, h, w = x.shape
|
||||
x_in = x
|
||||
spatial_context = None
|
||||
if exists(context):
|
||||
spatial_context = context
|
||||
|
||||
if self.use_spatial_context:
|
||||
assert (
|
||||
context.ndim == 3
|
||||
), f"n dims of spatial context should be 3 but are {context.ndim}"
|
||||
|
||||
if time_context is None:
|
||||
time_context = context
|
||||
time_context_first_timestep = time_context[::timesteps]
|
||||
time_context = repeat(
|
||||
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
|
||||
)
|
||||
elif time_context is not None and not self.use_spatial_context:
|
||||
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
|
||||
if time_context.ndim == 2:
|
||||
time_context = rearrange(time_context, "b c -> b 1 c")
|
||||
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
for it_, (block, mix_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.time_stack)
|
||||
):
|
||||
transformer_options["block_index"] = it_
|
||||
x = block(
|
||||
x,
|
||||
context=spatial_context,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x_mix = x
|
||||
x_mix = x_mix + emb
|
||||
|
||||
B, S, C = x_mix.shape
|
||||
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
|
||||
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
|
||||
x_mix = rearrange(
|
||||
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
|
||||
)
|
||||
|
||||
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
|
||||
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
|
@ -5,6 +5,8 @@ import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from functools import partial
|
||||
|
||||
from .util import (
|
||||
checkpoint,
|
||||
@ -12,8 +14,9 @@ from .util import (
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
AlphaBlender,
|
||||
)
|
||||
from ..attention import SpatialTransformer
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from comfy.ldm.util import exists
|
||||
import comfy.ops
|
||||
|
||||
@ -28,6 +31,26 @@ class TimestepBlock(nn.Module):
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialVideoTransformer):
|
||||
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
|
||||
if "transformer_index" in transformer_options:
|
||||
transformer_options["transformer_index"] += 1
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
if "transformer_index" in transformer_options:
|
||||
transformer_options["transformer_index"] += 1
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
@ -35,31 +58,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
|
||||
for layer in ts:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context, transformer_options)
|
||||
transformer_options["current_index"] += 1
|
||||
elif isinstance(layer, Upsample):
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
def forward(self, *args, **kwargs):
|
||||
return forward_timestep_embed(self, *args, **kwargs)
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
@ -154,6 +154,9 @@ class ResBlock(TimestepBlock):
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=3,
|
||||
exchange_temb_dims=False,
|
||||
skip_t_emb=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
@ -166,11 +169,17 @@ class ResBlock(TimestepBlock):
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
self.exchange_temb_dims = exchange_temb_dims
|
||||
|
||||
if isinstance(kernel_size, list):
|
||||
padding = [k // 2 for k in kernel_size]
|
||||
else:
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
|
||||
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
@ -184,19 +193,24 @@ class ResBlock(TimestepBlock):
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.skip_t_emb = skip_t_emb
|
||||
if self.skip_t_emb:
|
||||
self.emb_layers = None
|
||||
self.exchange_temb_dims = False
|
||||
else:
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
||||
),
|
||||
)
|
||||
|
||||
@ -204,7 +218,7 @@ class ResBlock(TimestepBlock):
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = operations.conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
|
||||
dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
|
||||
@ -230,19 +244,110 @@ class ResBlock(TimestepBlock):
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
|
||||
emb_out = None
|
||||
if not self.skip_t_emb:
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_norm(h)
|
||||
if emb_out is not None:
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h *= (1 + scale)
|
||||
h += shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
if emb_out is not None:
|
||||
if self.exchange_temb_dims:
|
||||
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class VideoResBlock(ResBlock):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
emb_channels: int,
|
||||
dropout: float,
|
||||
video_kernel_size=3,
|
||||
merge_strategy: str = "fixed",
|
||||
merge_factor: float = 0.5,
|
||||
out_channels=None,
|
||||
use_conv: bool = False,
|
||||
use_scale_shift_norm: bool = False,
|
||||
dims: int = 2,
|
||||
use_checkpoint: bool = False,
|
||||
up: bool = False,
|
||||
down: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
):
|
||||
super().__init__(
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=out_channels,
|
||||
use_conv=use_conv,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
up=up,
|
||||
down=down,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
self.time_stack = ResBlock(
|
||||
default(out_channels, channels),
|
||||
emb_channels,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
out_channels=default(out_channels, channels),
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=use_checkpoint,
|
||||
exchange_temb_dims=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.time_mixer = AlphaBlender(
|
||||
alpha=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
rearrange_pattern="b t -> b 1 t 1 1",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: th.Tensor,
|
||||
emb: th.Tensor,
|
||||
num_video_frames: int,
|
||||
image_only_indicator = None,
|
||||
) -> th.Tensor:
|
||||
x = super().forward(x, emb)
|
||||
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
||||
|
||||
x = self.time_stack(
|
||||
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
||||
)
|
||||
x = self.time_mixer(
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
||||
)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class Timestep(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@ -255,7 +360,10 @@ def apply_control(h, control, name):
|
||||
if control is not None and name in control and len(control[name]) > 0:
|
||||
ctrl = control[name].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
try:
|
||||
h += ctrl
|
||||
except:
|
||||
print("warning control could not be applied", h.shape, ctrl.shape)
|
||||
return h
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
@ -316,6 +424,16 @@ class UNetModel(nn.Module):
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
use_temporal_resblock=False,
|
||||
use_temporal_attention=False,
|
||||
time_context_dim=None,
|
||||
extra_ff_mix_layer=False,
|
||||
use_spatial_context=False,
|
||||
merge_strategy=None,
|
||||
merge_factor=0.0,
|
||||
video_kernel_size=None,
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
device=None,
|
||||
operations=comfy.ops,
|
||||
):
|
||||
@ -370,8 +488,12 @@ class UNetModel(nn.Module):
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.use_temporal_resblocks = use_temporal_resblock
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
self.default_num_video_frames = None
|
||||
self.default_image_only_indicator = None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
||||
@ -408,13 +530,104 @@ class UNetModel(nn.Module):
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
|
||||
def get_attention_layer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=1,
|
||||
context_dim=None,
|
||||
use_checkpoint=False,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
if use_temporal_attention:
|
||||
return SpatialVideoTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=depth,
|
||||
context_dim=context_dim,
|
||||
time_context_dim=time_context_dim,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
merge_strategy=merge_strategy,
|
||||
merge_factor=merge_factor,
|
||||
checkpoint=use_checkpoint,
|
||||
use_linear=use_linear_in_transformer,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
return SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def get_resblock(
|
||||
merge_factor,
|
||||
merge_strategy,
|
||||
video_kernel_size,
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels,
|
||||
dims,
|
||||
use_checkpoint,
|
||||
use_scale_shift_norm,
|
||||
down=False,
|
||||
up=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=comfy.ops
|
||||
):
|
||||
if self.use_temporal_resblocks:
|
||||
return VideoResBlock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
return ResBlock(
|
||||
channels=ch,
|
||||
emb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_channels,
|
||||
use_checkpoint=use_checkpoint,
|
||||
dims=dims,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=down,
|
||||
up=up,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -441,11 +654,9 @@ class UNetModel(nn.Module):
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(SpatialTransformer(
|
||||
layers.append(get_attention_layer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
@ -454,10 +665,13 @@ class UNetModel(nn.Module):
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -487,10 +701,14 @@ class UNetModel(nn.Module):
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
mid_block = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
@ -499,15 +717,18 @@ class UNetModel(nn.Module):
|
||||
operations=operations
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
mid_block += [get_attention_layer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
@ -523,10 +744,13 @@ class UNetModel(nn.Module):
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch + ich,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -554,19 +778,21 @@ class UNetModel(nn.Module):
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
get_attention_layer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
@ -605,9 +831,13 @@ class UNetModel(nn.Module):
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
transformer_options["original_shape"] = list(x.shape)
|
||||
transformer_options["current_index"] = 0
|
||||
transformer_options["transformer_index"] = 0
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
|
||||
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
||||
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
|
||||
time_context = kwargs.get("time_context", None)
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
@ -622,14 +852,24 @@ class UNetModel(nn.Module):
|
||||
h = x.type(self.dtype)
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'input')
|
||||
if "input_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch"]
|
||||
for p in patch:
|
||||
h = p(h, transformer_options)
|
||||
|
||||
hs.append(h)
|
||||
if "input_block_patch_after_skip" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch_after_skip"]
|
||||
for p in patch:
|
||||
h = p(h, transformer_options)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
hsp = hs.pop()
|
||||
@ -646,7 +886,7 @@ class UNetModel(nn.Module):
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
|
@ -13,11 +13,78 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
from einops import repeat, rearrange
|
||||
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
import comfy.ops
|
||||
|
||||
class AlphaBlender(nn.Module):
|
||||
strategies = ["learned", "fixed", "learned_with_images"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alpha: float,
|
||||
merge_strategy: str = "learned_with_images",
|
||||
rearrange_pattern: str = "b t -> (b t) 1 1",
|
||||
):
|
||||
super().__init__()
|
||||
self.merge_strategy = merge_strategy
|
||||
self.rearrange_pattern = rearrange_pattern
|
||||
|
||||
assert (
|
||||
merge_strategy in self.strategies
|
||||
), f"merge_strategy needs to be in {self.strategies}"
|
||||
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif (
|
||||
self.merge_strategy == "learned"
|
||||
or self.merge_strategy == "learned_with_images"
|
||||
):
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
||||
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
|
||||
if self.merge_strategy == "fixed":
|
||||
# make shape compatible
|
||||
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
||||
alpha = self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
alpha = torch.sigmoid(self.mix_factor)
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
elif self.merge_strategy == "learned_with_images":
|
||||
assert image_only_indicator is not None, "need image_only_indicator ..."
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
||||
)
|
||||
alpha = rearrange(alpha, self.rearrange_pattern)
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return alpha
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x_spatial,
|
||||
x_temporal,
|
||||
image_only_indicator=None,
|
||||
) -> torch.Tensor:
|
||||
alpha = self.get_alpha(image_only_indicator)
|
||||
x = (
|
||||
alpha.to(x_spatial.dtype) * x_spatial
|
||||
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
|
244
comfy/ldm/modules/temporal_ae.py
Normal file
244
comfy/ldm/modules/temporal_ae.py
Normal file
@ -0,0 +1,244 @@
|
||||
import functools
|
||||
from typing import Callable, Iterable, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import comfy.ops
|
||||
|
||||
from .diffusionmodules.model import (
|
||||
AttnBlock,
|
||||
Decoder,
|
||||
ResnetBlock,
|
||||
)
|
||||
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
|
||||
from .attention import BasicTransformerBlock
|
||||
|
||||
def partialclass(cls, *args, **kwargs):
|
||||
class NewCls(cls):
|
||||
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
|
||||
|
||||
return NewCls
|
||||
|
||||
|
||||
class VideoResBlock(ResnetBlock):
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
*args,
|
||||
dropout=0.0,
|
||||
video_kernel_size=3,
|
||||
alpha=0.0,
|
||||
merge_strategy="learned",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
|
||||
if video_kernel_size is None:
|
||||
video_kernel_size = [3, 1, 1]
|
||||
self.time_stack = ResBlock(
|
||||
channels=out_channels,
|
||||
emb_channels=0,
|
||||
dropout=dropout,
|
||||
dims=3,
|
||||
use_scale_shift_norm=False,
|
||||
use_conv=False,
|
||||
up=False,
|
||||
down=False,
|
||||
kernel_size=video_kernel_size,
|
||||
use_checkpoint=False,
|
||||
skip_t_emb=True,
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def get_alpha(self, bs):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, temb, skip_video=False, timesteps=None):
|
||||
b, c, h, w = x.shape
|
||||
if timesteps is None:
|
||||
timesteps = b
|
||||
|
||||
x = super().forward(x, temb)
|
||||
|
||||
if not skip_video:
|
||||
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
|
||||
x = self.time_stack(x, temb)
|
||||
|
||||
alpha = self.get_alpha(bs=b // timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix
|
||||
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class AE3DConv(torch.nn.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
||||
super().__init__(in_channels, out_channels, *args, **kwargs)
|
||||
if isinstance(video_kernel_size, Iterable):
|
||||
padding = [int(k // 2) for k in video_kernel_size]
|
||||
else:
|
||||
padding = int(video_kernel_size // 2)
|
||||
|
||||
self.time_mix_conv = torch.nn.Conv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=video_kernel_size,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
def forward(self, input, timesteps=None, skip_video=False):
|
||||
if timesteps is None:
|
||||
timesteps = input.shape[0]
|
||||
x = super().forward(input)
|
||||
if skip_video:
|
||||
return x
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
||||
x = self.time_mix_conv(x)
|
||||
return rearrange(x, "b c t h w -> (b t) c h w")
|
||||
|
||||
|
||||
class AttnVideoBlock(AttnBlock):
|
||||
def __init__(
|
||||
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
|
||||
):
|
||||
super().__init__(in_channels)
|
||||
# no context, single headed, as in base class
|
||||
self.time_mix_block = BasicTransformerBlock(
|
||||
dim=in_channels,
|
||||
n_heads=1,
|
||||
d_head=in_channels,
|
||||
checkpoint=False,
|
||||
ff_in=True,
|
||||
)
|
||||
|
||||
time_embed_dim = self.in_channels * 4
|
||||
self.video_time_embed = torch.nn.Sequential(
|
||||
comfy.ops.Linear(self.in_channels, time_embed_dim),
|
||||
torch.nn.SiLU(),
|
||||
comfy.ops.Linear(time_embed_dim, self.in_channels),
|
||||
)
|
||||
|
||||
self.merge_strategy = merge_strategy
|
||||
if self.merge_strategy == "fixed":
|
||||
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
||||
elif self.merge_strategy == "learned":
|
||||
self.register_parameter(
|
||||
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
def forward(self, x, timesteps=None, skip_time_block=False):
|
||||
if skip_time_block:
|
||||
return super().forward(x)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = x.shape[0]
|
||||
|
||||
x_in = x
|
||||
x = self.attention(x)
|
||||
h, w = x.shape[2:]
|
||||
x = rearrange(x, "b c h w -> b (h w) c")
|
||||
|
||||
x_mix = x
|
||||
num_frames = torch.arange(timesteps, device=x.device)
|
||||
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
||||
num_frames = rearrange(num_frames, "b t -> (b t)")
|
||||
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
||||
emb = self.video_time_embed(t_emb) # b, n_channels
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
|
||||
|
||||
alpha = self.get_alpha()
|
||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x_in + x
|
||||
|
||||
def get_alpha(
|
||||
self,
|
||||
):
|
||||
if self.merge_strategy == "fixed":
|
||||
return self.mix_factor
|
||||
elif self.merge_strategy == "learned":
|
||||
return torch.sigmoid(self.mix_factor)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
||||
|
||||
|
||||
|
||||
def make_time_attn(
|
||||
in_channels,
|
||||
attn_type="vanilla",
|
||||
attn_kwargs=None,
|
||||
alpha: float = 0,
|
||||
merge_strategy: str = "learned",
|
||||
):
|
||||
return partialclass(
|
||||
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
||||
)
|
||||
|
||||
|
||||
class Conv2DWrapper(torch.nn.Conv2d):
|
||||
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
class VideoDecoder(Decoder):
|
||||
available_time_modes = ["all", "conv-only", "attn-only"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
video_kernel_size: Union[int, list] = 3,
|
||||
alpha: float = 0.0,
|
||||
merge_strategy: str = "learned",
|
||||
time_mode: str = "conv-only",
|
||||
**kwargs,
|
||||
):
|
||||
self.video_kernel_size = video_kernel_size
|
||||
self.alpha = alpha
|
||||
self.merge_strategy = merge_strategy
|
||||
self.time_mode = time_mode
|
||||
assert (
|
||||
self.time_mode in self.available_time_modes
|
||||
), f"time_mode parameter has to be in {self.available_time_modes}"
|
||||
|
||||
if self.time_mode != "attn-only":
|
||||
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
|
||||
if self.time_mode not in ["conv-only", "only-last-conv"]:
|
||||
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
|
||||
if self.time_mode not in ["attn-only", "only-last-conv"]:
|
||||
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_last_layer(self, skip_time_mix=False, **kwargs):
|
||||
if self.time_mode == "attn-only":
|
||||
raise NotImplementedError("TODO")
|
||||
else:
|
||||
return (
|
||||
self.conv_out.time_mix_conv.weight
|
||||
if not skip_time_mix
|
||||
else self.conv_out.weight
|
||||
)
|
@ -10,17 +10,22 @@ from . import utils
|
||||
class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
V_PREDICTION_EDM = 3
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
s = ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
|
||||
s = ModelSamplingDiscrete
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -121,6 +126,7 @@ class BaseModel(torch.nn.Module):
|
||||
if k.startswith(unet_prefix):
|
||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
if len(m) > 0:
|
||||
print("unet missing:", m)
|
||||
@ -157,6 +163,17 @@ class BaseModel(torch.nn.Module):
|
||||
def set_inpaint(self):
|
||||
self.inpaint_model = True
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
#TODO: this needs to be tweaked
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
|
||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||
adm_inputs = []
|
||||
weights = []
|
||||
@ -251,3 +268,48 @@ class SDXL(BaseModel):
|
||||
out.append(self.embedder(torch.Tensor([target_width])))
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
|
||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||
|
||||
class SVD_img2vid(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.embedder = Timestep(256)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
fps_id = kwargs.get("fps", 6) - 1
|
||||
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
|
||||
augmentation = kwargs.get("augmentation_level", 0)
|
||||
|
||||
out = []
|
||||
out.append(self.embedder(torch.Tensor([fps_id])))
|
||||
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
|
||||
out.append(self.embedder(torch.Tensor([augmentation])))
|
||||
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
|
||||
return flat
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
|
||||
latent_image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if latent_image is None:
|
||||
latent_image = torch.zeros_like(noise)
|
||||
|
||||
if latent_image.shape[1:] != noise.shape[1:]:
|
||||
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
|
||||
|
||||
if "time_conditioning" in kwargs:
|
||||
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
|
||||
|
||||
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
|
||||
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
|
||||
return out
|
||||
|
@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer
|
||||
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
|
||||
video_model = False
|
||||
|
||||
current_res = 1
|
||||
count = 0
|
||||
@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
if context_dim is None:
|
||||
context_dim = out[1]
|
||||
use_linear_in_transformer = out[2]
|
||||
video_model = out[3]
|
||||
else:
|
||||
transformer_depth.append(0)
|
||||
|
||||
@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
unet_config["transformer_depth_middle"] = transformer_depth_middle
|
||||
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
|
||||
unet_config["context_dim"] = context_dim
|
||||
|
||||
if video_model:
|
||||
unet_config["extra_ff_mix_layer"] = True
|
||||
unet_config["use_spatial_context"] = True
|
||||
unet_config["merge_strategy"] = "learned_with_images"
|
||||
unet_config["merge_factor"] = 0.0
|
||||
unet_config["video_kernel_size"] = [3, 1, 1]
|
||||
unet_config["use_temporal_resblock"] = True
|
||||
unet_config["use_temporal_attention"] = True
|
||||
else:
|
||||
unet_config["use_temporal_resblock"] = False
|
||||
unet_config["use_temporal_attention"] = False
|
||||
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config):
|
||||
@ -186,17 +202,24 @@ def convert_config(unet_config):
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
match = {}
|
||||
attention_resolutions = []
|
||||
transformer_depth = []
|
||||
|
||||
attn_res = 1
|
||||
for i in range(5):
|
||||
k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
|
||||
if k in state_dict:
|
||||
match["context_dim"] = state_dict[k].shape[1]
|
||||
attention_resolutions.append(attn_res)
|
||||
attn_res *= 2
|
||||
down_blocks = count_blocks(state_dict, "down_blocks.{}")
|
||||
for i in range(down_blocks):
|
||||
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
||||
for ab in range(attn_blocks):
|
||||
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
||||
transformer_depth.append(transformer_count)
|
||||
if transformer_count > 0:
|
||||
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
|
||||
|
||||
match["attention_resolutions"] = attention_resolutions
|
||||
attn_res *= 2
|
||||
if attn_blocks == 0:
|
||||
transformer_depth.append(0)
|
||||
transformer_depth.append(0)
|
||||
|
||||
match["transformer_depth"] = transformer_depth
|
||||
|
||||
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
||||
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
||||
@ -208,50 +231,65 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
|
||||
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
|
||||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
|
||||
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
|
||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
@ -133,6 +133,10 @@ else:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
try:
|
||||
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
XFORMERS_VERSION = xformers.version.__version__
|
||||
print("xformers version:", XFORMERS_VERSION)
|
||||
@ -478,6 +482,21 @@ def text_encoder_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def text_encoder_dtype(device=None):
|
||||
if args.fp8_e4m3fn_text_enc:
|
||||
return torch.float8_e4m3fn
|
||||
elif args.fp8_e5m2_text_enc:
|
||||
return torch.float8_e5m2
|
||||
elif args.fp16_text_enc:
|
||||
return torch.float16
|
||||
elif args.fp32_text_enc:
|
||||
return torch.float32
|
||||
|
||||
if should_use_fp16(device, prioritize_performance=False):
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def vae_device():
|
||||
return get_torch_device()
|
||||
|
||||
@ -579,27 +598,6 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
else:
|
||||
return mem_free_total
|
||||
|
||||
def batch_area_memory(area):
|
||||
if xformers_enabled() or pytorch_attention_flash_attention():
|
||||
#TODO: these formulas are copied from maximum_batch_area below
|
||||
return (area / 20) * (1024 * 1024)
|
||||
else:
|
||||
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||
|
||||
def maximum_batch_area():
|
||||
global vram_state
|
||||
if vram_state == VRAMState.NO_VRAM:
|
||||
return 0
|
||||
|
||||
memory_free = get_free_memory() / (1024 * 1024)
|
||||
if xformers_enabled() or pytorch_attention_flash_attention():
|
||||
#TODO: this needs to be tweaked
|
||||
area = 20 * memory_free
|
||||
else:
|
||||
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
|
||||
area = ((memory_free - 1024) * 0.9) / (0.6)
|
||||
return int(max(area, 0))
|
||||
|
||||
def cpu_mode():
|
||||
global cpu_state
|
||||
return cpu_state == CPUState.CPU
|
||||
|
@ -37,7 +37,7 @@ class ModelPatcher:
|
||||
return size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@ -52,6 +52,9 @@ class ModelPatcher:
|
||||
return True
|
||||
return False
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
@ -93,6 +96,12 @@ class ModelPatcher:
|
||||
def set_model_attn2_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_output_patch")
|
||||
|
||||
def set_model_input_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "input_block_patch")
|
||||
|
||||
def set_model_input_block_patch_after_skip(self, patch):
|
||||
self.set_model_patch(patch, "input_block_patch_after_skip")
|
||||
|
||||
def set_model_output_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
|
||||
import math
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
@ -24,7 +24,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
super().__init__()
|
||||
beta_schedule = "linear"
|
||||
if model_config is not None:
|
||||
beta_schedule = model_config.beta_schedule
|
||||
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
@ -65,16 +65,65 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
|
||||
|
||||
def sigma(self, timestep):
|
||||
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
|
||||
t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
return self.sigma(torch.tensor(percent * 999.0))
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||
|
||||
|
||||
class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
sigma_min = sampling_settings.get("sigma_min", 0.002)
|
||||
sigma_max = sampling_settings.get("sigma_max", 120.0)
|
||||
self.set_sigma_range(sigma_min, sigma_max)
|
||||
|
||||
def set_sigma_range(self, sigma_min, sigma_max):
|
||||
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
|
||||
|
||||
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return 0.25 * sigma.log()
|
||||
|
||||
def sigma(self, timestep):
|
||||
return (timestep / 0.25).exp()
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
percent = 1.0 - percent
|
||||
|
||||
log_sigma_min = math.log(self.sigma_min)
|
||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||
|
@ -83,7 +83,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
|
||||
|
||||
real_model = None
|
||||
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
|
||||
comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
|
||||
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, positive, negative, noise_mask, models
|
||||
|
@ -11,7 +11,7 @@ import comfy.conds
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
@ -134,7 +134,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
@ -170,9 +170,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area):
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
@ -218,12 +220,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
@ -242,11 +246,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
return out_cond, out_uncond
|
||||
|
||||
|
||||
max_total_area = model_management.maximum_batch_area()
|
||||
if math.isclose(cond_scale, 1.0):
|
||||
uncond = None
|
||||
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
||||
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
return x - model_options["sampler_cfg_function"](args)
|
||||
@ -258,7 +261,7 @@ class CFGNoisePredictor(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
@ -511,52 +514,69 @@ class Sampler:
|
||||
|
||||
class UNIPC(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
|
||||
|
||||
class UNIPCBH2(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
self.sampler_function = sampler_function
|
||||
self.extra_options = extra_options
|
||||
self.inpaint_options = inpaint_options
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k.latent_image = latent_image
|
||||
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
if self.max_denoise(model_wrap, sigmas):
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
|
||||
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
|
||||
return samples
|
||||
|
||||
|
||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
class KSAMPLER(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k.latent_image = latent_image
|
||||
if inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
if self.max_denoise(model_wrap, sigmas):
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
else:
|
||||
noise = noise * sigmas[0]
|
||||
|
||||
k_callback = None
|
||||
total_steps = len(sigmas) - 1
|
||||
if callback is not None:
|
||||
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
|
||||
|
||||
if sampler_name == "dpm_fast":
|
||||
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
|
||||
sigma_min = sigmas[-1]
|
||||
if sigma_min == 0:
|
||||
sigma_min = sigmas[-2]
|
||||
total_steps = len(sigmas) - 1
|
||||
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
|
||||
sampler_function = dpm_fast_function
|
||||
elif sampler_name == "dpm_adaptive":
|
||||
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
|
||||
sigma_min = sigmas[-1]
|
||||
if sigma_min == 0:
|
||||
sigma_min = sigmas[-2]
|
||||
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
|
||||
sampler_function = dpm_adaptive_function
|
||||
else:
|
||||
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||
|
||||
if latent_image is not None:
|
||||
noise += latent_image
|
||||
if sampler_name == "dpm_fast":
|
||||
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
elif sampler_name == "dpm_adaptive":
|
||||
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
|
||||
else:
|
||||
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
|
||||
return samples
|
||||
return KSAMPLER
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
@ -617,11 +637,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
return sigmas
|
||||
|
||||
def sampler_class(name):
|
||||
def sampler_object(name):
|
||||
if name == "uni_pc":
|
||||
sampler = UNIPC
|
||||
sampler = UNIPC()
|
||||
elif name == "uni_pc_bh2":
|
||||
sampler = UNIPCBH2
|
||||
sampler = UNIPCBH2()
|
||||
elif name == "ddim":
|
||||
sampler = ksampler("euler", inpaint_options={"random": True})
|
||||
else:
|
||||
@ -686,6 +706,6 @@ class KSampler:
|
||||
else:
|
||||
return torch.zeros_like(noise)
|
||||
|
||||
sampler = sampler_class(self.sampler)
|
||||
sampler = sampler_object(self.sampler)
|
||||
|
||||
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
|
63
comfy/sd.py
63
comfy/sd.py
@ -23,6 +23,7 @@ import comfy.model_patcher
|
||||
import comfy.lora
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.supported_models_base
|
||||
import comfy.taesd.taesd
|
||||
|
||||
def load_model_weights(model, sd):
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
@ -95,10 +96,7 @@ class CLIP:
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
params['device'] = offload_device
|
||||
if model_management.should_use_fp16(load_device, prioritize_performance=False):
|
||||
params['dtype'] = torch.float16
|
||||
else:
|
||||
params['dtype'] = torch.float32
|
||||
params['dtype'] = model_management.text_encoder_dtype(load_device)
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
@ -157,10 +155,24 @@ class VAE:
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
|
||||
if config is None:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
decoder_config = encoder_config.copy()
|
||||
decoder_config["video_kernel_size"] = [3, 1, 1]
|
||||
decoder_config["alpha"] = 0.0
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD()
|
||||
else:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
@ -175,10 +187,12 @@ class VAE:
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
self.offload_device = model_management.vae_offload_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.vae_dtype = model_management.vae_dtype()
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@ -207,10 +221,9 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def decode(self, samples_in):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
try:
|
||||
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
|
||||
model_management.free_memory(memory_used, self.device)
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
@ -223,22 +236,19 @@ class VAE:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
|
||||
model_management.free_memory(memory_used, self.device)
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
@ -251,14 +261,12 @@ class VAE:
|
||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
self.first_stage_model = self.first_stage_model.to(self.offload_device)
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
@ -444,6 +452,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae = VAE(sd=vae_sd)
|
||||
|
||||
if output_clip:
|
||||
@ -468,20 +477,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return None
|
||||
new_sd = sd
|
||||
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||
if model_config is None:
|
||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||
return None
|
||||
|
||||
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
|
||||
@ -501,6 +508,14 @@ def load_unet(unet_path): #load unet in diffusers format
|
||||
print("left over keys in unet:", left_over)
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_unet_state_dict(sd)
|
||||
if model is None:
|
||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
return model
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
model_management.load_models_gpu([model, clip.load_model()])
|
||||
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
|
||||
|
@ -173,9 +173,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(device), torch.float32):
|
||||
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
|
@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
@ -69,6 +71,10 @@ class SD20(supported_models_base.BASE):
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||
return state_dict
|
||||
|
||||
@ -88,6 +94,7 @@ class SD21UnclipL(SD20):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": 1536,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "embedder.model.visual."
|
||||
@ -100,6 +107,7 @@ class SD21UnclipH(SD20):
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": 2048,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "embedder.model.visual."
|
||||
@ -112,6 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
"context_dim": 1280,
|
||||
"adm_in_channels": 2560,
|
||||
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
@ -148,7 +157,8 @@ class SDXL(supported_models_base.BASE):
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
@ -203,8 +213,34 @@ class SSD1B(SDXL):
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 4, 4],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
class SVD_img2vid(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"in_channels": 8,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
"context_dim": 1024,
|
||||
"adm_in_channels": 768,
|
||||
"use_temporal_attention": True,
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SVD_img2vid(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self):
|
||||
return None
|
||||
|
||||
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
|
||||
models += [SVD_img2vid]
|
||||
|
@ -19,7 +19,7 @@ class BASE:
|
||||
clip_prefix = []
|
||||
clip_vision_prefix = None
|
||||
noise_aug_config = None
|
||||
beta_schedule = "linear"
|
||||
sampling_settings = {}
|
||||
latent_format = latent_formats.LatentFormat
|
||||
|
||||
@classmethod
|
||||
@ -53,6 +53,12 @@ class BASE:
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_vae_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
@ -46,15 +46,16 @@ class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
|
||||
def __init__(self, encoder_path=None, decoder_path=None):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.taesd_encoder = Encoder()
|
||||
self.taesd_decoder = Decoder()
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
if encoder_path is not None:
|
||||
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
||||
if decoder_path is not None:
|
||||
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
@ -65,3 +66,11 @@ class TAESD(nn.Module):
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
def decode(self, x):
|
||||
x_sample = self.taesd_decoder(x * self.vae_scale)
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
return x_sample
|
||||
|
||||
def encode(self, x):
|
||||
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
|
||||
|
@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
|
||||
del prev
|
||||
|
||||
def copy_to_param(obj, attr, value):
|
||||
@ -307,23 +307,25 @@ def bislerp(samples, width, height):
|
||||
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
||||
return res
|
||||
|
||||
def generate_bilinear_data(length_old, length_new):
|
||||
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
|
||||
def generate_bilinear_data(length_old, length_new, device):
|
||||
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
|
||||
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
|
||||
ratios = coords_1 - coords_1.floor()
|
||||
coords_1 = coords_1.to(torch.int64)
|
||||
|
||||
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
|
||||
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
|
||||
coords_2[:,:,:,-1] -= 1
|
||||
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
|
||||
coords_2 = coords_2.to(torch.int64)
|
||||
return ratios, coords_1, coords_2
|
||||
|
||||
|
||||
orig_dtype = samples.dtype
|
||||
samples = samples.float()
|
||||
n,c,h,w = samples.shape
|
||||
h_new, w_new = (height, width)
|
||||
|
||||
#linear w
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
|
||||
coords_1 = coords_1.expand((n, c, h, -1))
|
||||
coords_2 = coords_2.expand((n, c, h, -1))
|
||||
ratios = ratios.expand((n, 1, h, -1))
|
||||
@ -336,7 +338,7 @@ def bislerp(samples, width, height):
|
||||
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
|
||||
|
||||
#linear h
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
|
||||
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
|
||||
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
|
||||
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
|
||||
@ -347,7 +349,7 @@ def bislerp(samples, width, height):
|
||||
|
||||
result = slerp(pass_1, pass_2, ratios)
|
||||
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
|
||||
return result
|
||||
return result.to(orig_dtype)
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
|
@ -16,7 +16,7 @@ class BasicScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -36,7 +36,7 @@ class KarrasScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -54,7 +54,7 @@ class ExponentialScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -73,7 +73,7 @@ class PolyexponentialScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -81,6 +81,25 @@ class PolyexponentialScheduler:
|
||||
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
||||
return (sigmas, )
|
||||
|
||||
class SDTurboScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, model, steps):
|
||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
|
||||
sigmas = model.model.model_sampling.sigma(timesteps)
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
return (sigmas, )
|
||||
|
||||
class VPScheduler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -92,7 +111,7 @@ class VPScheduler:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/schedulers"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -109,7 +128,7 @@ class SplitSigmas:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS","SIGMAS")
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
@ -118,6 +137,24 @@ class SplitSigmas:
|
||||
sigmas2 = sigmas[step:]
|
||||
return (sigmas1, sigmas2)
|
||||
|
||||
class FlipSigmas:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"sigmas": ("SIGMAS", ),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
CATEGORY = "sampling/custom_sampling/sigmas"
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, sigmas):
|
||||
sigmas = sigmas.flip(0)
|
||||
if sigmas[0] == 0:
|
||||
sigmas[0] = 0.0001
|
||||
return (sigmas,)
|
||||
|
||||
class KSamplerSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -126,12 +163,12 @@ class KSamplerSelect:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, sampler_name):
|
||||
sampler = comfy.samplers.sampler_class(sampler_name)()
|
||||
sampler = comfy.samplers.sampler_object(sampler_name)
|
||||
return (sampler, )
|
||||
|
||||
class SamplerDPMPP_2M_SDE:
|
||||
@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
|
||||
sampler_name = "dpmpp_2m_sde"
|
||||
else:
|
||||
sampler_name = "dpmpp_2m_sde_gpu"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
||||
return (sampler, )
|
||||
|
||||
|
||||
@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling"
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
|
||||
sampler_name = "dpmpp_sde"
|
||||
else:
|
||||
sampler_name = "dpmpp_sde_gpu"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||
return (sampler, )
|
||||
|
||||
class SamplerCustom:
|
||||
@ -234,13 +271,15 @@ class SamplerCustom:
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SamplerCustom": SamplerCustom,
|
||||
"BasicScheduler": BasicScheduler,
|
||||
"KarrasScheduler": KarrasScheduler,
|
||||
"ExponentialScheduler": ExponentialScheduler,
|
||||
"PolyexponentialScheduler": PolyexponentialScheduler,
|
||||
"VPScheduler": VPScheduler,
|
||||
"SDTurboScheduler": SDTurboScheduler,
|
||||
"KSamplerSelect": KSamplerSelect,
|
||||
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
|
||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||
"BasicScheduler": BasicScheduler,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
}
|
||||
|
175
comfy_extras/nodes_images.py
Normal file
175
comfy_extras/nodes_images.py
Normal file
@ -0,0 +1,175 @@
|
||||
import nodes
|
||||
import folder_paths
|
||||
from comfy.cli_args import args
|
||||
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
|
||||
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
||||
|
||||
class ImageCrop:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "crop"
|
||||
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def crop(self, image, width, height, x, y):
|
||||
x = min(x, image.shape[2] - 1)
|
||||
y = min(y, image.shape[1] - 1)
|
||||
to_x = width + x
|
||||
to_y = height + y
|
||||
img = image[:,y:to_y, x:to_x, :]
|
||||
return (img,)
|
||||
|
||||
class RepeatImageBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def repeat(self, image, amount):
|
||||
s = image.repeat((amount, 1,1,1))
|
||||
return (s,)
|
||||
|
||||
class SaveAnimatedWEBP:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
methods = {"default": 4, "fastest": 0, "slowest": 6}
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"lossless": ("BOOLEAN", {"default": True}),
|
||||
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
|
||||
"method": (list(s.methods.keys()),),
|
||||
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
||||
method = self.methods.get(method)
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = pil_images[0].getexif()
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
inital_exif = 0x010f
|
||||
for x in extra_pnginfo:
|
||||
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
|
||||
inital_exif -= 1
|
||||
|
||||
if num_frames == 0:
|
||||
num_frames = len(pil_images)
|
||||
|
||||
c = len(pil_images)
|
||||
for i in range(0, c, num_frames):
|
||||
file = f"{filename}_{counter:05}_.webp"
|
||||
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
animated = num_frames != 1
|
||||
return { "ui": { "images": results, "animated": (animated,) } }
|
||||
|
||||
class SaveAnimatedPNG:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"images": ("IMAGE", ),
|
||||
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
|
||||
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
|
||||
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_images"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
|
||||
results = list()
|
||||
pil_images = []
|
||||
for image in images:
|
||||
i = 255. * image.cpu().numpy()
|
||||
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
|
||||
pil_images.append(img)
|
||||
|
||||
metadata = None
|
||||
if not args.disable_metadata:
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
|
||||
return { "ui": { "images": results, "animated": (True,)} }
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
"SaveAnimatedWEBP": SaveAnimatedWEBP,
|
||||
"SaveAnimatedPNG": SaveAnimatedPNG,
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
import comfy.utils
|
||||
import torch
|
||||
|
||||
def reshape_latent_to(target_shape, latent):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@ -67,8 +68,43 @@ class LatentMultiply:
|
||||
samples_out["samples"] = s1 * multiplier
|
||||
return (samples_out,)
|
||||
|
||||
class LatentInterpolate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",),
|
||||
"samples2": ("LATENT",),
|
||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2, ratio):
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
|
||||
m1 = torch.linalg.vector_norm(s1, dim=(1))
|
||||
m2 = torch.linalg.vector_norm(s2, dim=(1))
|
||||
|
||||
s1 = torch.nan_to_num(s1 / m1)
|
||||
s2 = torch.nan_to_num(s2 / m2)
|
||||
|
||||
t = (s1 * ratio + s2 * (1.0 - ratio))
|
||||
mt = torch.linalg.vector_norm(t, dim=(1))
|
||||
st = torch.nan_to_num(t / mt)
|
||||
|
||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return (samples_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
}
|
||||
|
@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
|
||||
|
||||
return c_out * x0 + c_skip * model_input
|
||||
|
||||
class ModelSamplingDiscreteLCM(torch.nn.Module):
|
||||
class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
||||
original_timesteps = 50
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
original_timesteps = 50
|
||||
self.skip_steps = timesteps // original_timesteps
|
||||
self.skip_steps = timesteps // self.original_timesteps
|
||||
|
||||
|
||||
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
|
||||
for x in range(original_timesteps):
|
||||
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
||||
for x in range(self.original_timesteps):
|
||||
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
|
||||
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
@ -55,18 +56,23 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
|
||||
return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
|
||||
|
||||
def sigma(self, timestep):
|
||||
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
|
||||
t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
return self.sigma(torch.tensor(percent * 999.0))
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
@ -111,7 +117,7 @@ class ModelSamplingDiscrete:
|
||||
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||
elif sampling == "lcm":
|
||||
sampling_type = LCM
|
||||
sampling_base = ModelSamplingDiscreteLCM
|
||||
sampling_base = ModelSamplingDiscreteDistilled
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
@ -123,6 +129,36 @@ class ModelSamplingDiscrete:
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class ModelSamplingContinuousEDM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["v_prediction", "eps"],),
|
||||
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, sampling, sigma_max, sigma_min):
|
||||
m = model.clone()
|
||||
|
||||
if sampling == "eps":
|
||||
sampling_type = comfy.model_sampling.EPS
|
||||
elif sampling == "v_prediction":
|
||||
sampling_type = comfy.model_sampling.V_PREDICTION
|
||||
|
||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced()
|
||||
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class RescaleCFG:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -164,5 +200,6 @@ class RescaleCFG:
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
}
|
||||
|
53
comfy_extras/nodes_model_downscale.py
Normal file
53
comfy_extras/nodes_model_downscale.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
class PatchModelAddDownscale:
|
||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
|
||||
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
|
||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
||||
"downscale_method": (s.upscale_methods,),
|
||||
"upscale_method": (s.upscale_methods,),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
||||
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
|
||||
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
|
||||
|
||||
def input_block_patch(h, transformer_options):
|
||||
if transformer_options["block"][1] == block_number:
|
||||
sigma = transformer_options["sigmas"][0].item()
|
||||
if sigma <= sigma_start and sigma >= sigma_end:
|
||||
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
|
||||
return h
|
||||
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
if h.shape[2] != hsp.shape[2]:
|
||||
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
|
||||
return h, hsp
|
||||
|
||||
m = model.clone()
|
||||
if downscale_after_skip:
|
||||
m.set_model_input_block_patch_after_skip(input_block_patch)
|
||||
else:
|
||||
m.set_model_input_block_patch(input_block_patch)
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PatchModelAddDownscale": PatchModelAddDownscale,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
|
||||
}
|
89
comfy_extras/nodes_video_model.py
Normal file
89
comfy_extras/nodes_video_model.py
Normal file
@ -0,0 +1,89 @@
|
||||
import nodes
|
||||
import torch
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
|
||||
|
||||
class ImageOnlyCheckpointLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "loaders/video_models"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (out[0], out[3], out[2])
|
||||
|
||||
|
||||
class SVD_img2vid_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_vision": ("CLIP_VISION",),
|
||||
"init_image": ("IMAGE",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
|
||||
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
|
||||
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
|
||||
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
|
||||
output = clip_vision.encode_image(init_image)
|
||||
pooled = output.image_embeds.unsqueeze(0)
|
||||
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
||||
encode_pixels = pixels[:,:,:,:3]
|
||||
if augmentation_level > 0:
|
||||
encode_pixels += torch.randn_like(pixels) * augmentation_level
|
||||
t = vae.encode(encode_pixels)
|
||||
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
|
||||
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
|
||||
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
|
||||
return (positive, negative, {"samples":latent})
|
||||
|
||||
class VideoLinearCFGGuidance:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "sampling/video_models"
|
||||
|
||||
def patch(self, model, min_cfg):
|
||||
def linear_cfg(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
|
||||
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
|
||||
return uncond + scale * (cond - uncond)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(linear_cfg)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
|
||||
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
|
||||
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
|
||||
}
|
23
execution.py
23
execution.py
@ -681,6 +681,7 @@ def validate_prompt(prompt):
|
||||
|
||||
return (True, None, list(good_outputs), node_errors)
|
||||
|
||||
MAXIMUM_HISTORY_SIZE = 10000
|
||||
|
||||
class PromptQueue:
|
||||
def __init__(self, server):
|
||||
@ -699,10 +700,12 @@ class PromptQueue:
|
||||
self.server.queue_updated()
|
||||
self.not_empty.notify()
|
||||
|
||||
def get(self):
|
||||
def get(self, timeout=None):
|
||||
with self.not_empty:
|
||||
while len(self.queue) == 0:
|
||||
self.not_empty.wait()
|
||||
self.not_empty.wait(timeout=timeout)
|
||||
if timeout is not None and len(self.queue) == 0:
|
||||
return None
|
||||
item = heapq.heappop(self.queue)
|
||||
i = self.task_counter
|
||||
self.currently_running[i] = copy.deepcopy(item)
|
||||
@ -713,6 +716,8 @@ class PromptQueue:
|
||||
def task_done(self, item_id, outputs):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||
self.history.pop(next(iter(self.history)))
|
||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||
for o in outputs:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
@ -747,10 +752,20 @@ class PromptQueue:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_history(self, prompt_id=None):
|
||||
def get_history(self, prompt_id=None, max_items=None, offset=-1):
|
||||
with self.mutex:
|
||||
if prompt_id is None:
|
||||
return copy.deepcopy(self.history)
|
||||
out = {}
|
||||
i = 0
|
||||
if offset < 0 and max_items is not None:
|
||||
offset = len(self.history) - max_items
|
||||
for k in self.history:
|
||||
if i >= offset:
|
||||
out[k] = self.history[k]
|
||||
if max_items is not None and len(out) >= max_items:
|
||||
break
|
||||
i += 1
|
||||
return out
|
||||
elif prompt_id in self.history:
|
||||
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
|
||||
else:
|
||||
|
@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
|
||||
filename_list_cache = {}
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
os.makedirs(input_directory)
|
||||
try:
|
||||
os.makedirs(input_directory)
|
||||
except:
|
||||
print("Failed to create input directory")
|
||||
|
||||
def set_output_directory(output_dir):
|
||||
global output_directory
|
||||
@ -228,8 +231,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
|
||||
full_output_folder = os.path.join(output_dir, subfolder)
|
||||
|
||||
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
|
||||
print("Saving image outside the output folder is not allowed.")
|
||||
return {}
|
||||
err = "**** ERROR: Saving image outside the output folder is not allowed." + \
|
||||
"\n full_output_folder: " + os.path.abspath(full_output_folder) + \
|
||||
"\n output_dir: " + output_dir + \
|
||||
"\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
|
||||
print(err)
|
||||
raise Exception(err)
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
|
@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
|
||||
self.taesd = taesd
|
||||
|
||||
def decode_latent_to_preview(self, x0):
|
||||
x_sample = self.taesd.decoder(x0[:1])[0].detach()
|
||||
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
|
||||
x_sample = self.taesd.decode(x0[:1])[0].detach()
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
41
main.py
41
main.py
@ -88,18 +88,37 @@ def cuda_malloc_warning():
|
||||
|
||||
def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
execution_start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
|
||||
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
while True:
|
||||
timeout = None
|
||||
if need_gc:
|
||||
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
|
||||
|
||||
queue_item = q.get(timeout=timeout)
|
||||
if queue_item is not None:
|
||||
item, item_id = queue_item
|
||||
execution_start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
print("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||
|
||||
if need_gc:
|
||||
current_time = time.perf_counter()
|
||||
if (current_time - last_gc_collect) > gc_collect_interval:
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
|
||||
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
|
||||
|
86
nodes.py
86
nodes.py
@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d['start_percent'] = 1.0 - start
|
||||
d['end_percent'] = 1.0 - end
|
||||
d['start_percent'] = start
|
||||
d['end_percent'] = end
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
@ -572,10 +572,69 @@ class LoraLoader:
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
||||
return (model_lora, clip_lora)
|
||||
|
||||
class VAELoader:
|
||||
class LoraLoaderModelOnly(LoraLoader):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_lora_model_only"
|
||||
|
||||
def load_lora_model_only(self, model, lora_name, strength_model):
|
||||
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
||||
|
||||
class VAELoader:
|
||||
@staticmethod
|
||||
def vae_list():
|
||||
vaes = folder_paths.get_filename_list("vae")
|
||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||
sdxl_taesd_enc = False
|
||||
sdxl_taesd_dec = False
|
||||
sd1_taesd_enc = False
|
||||
sd1_taesd_dec = False
|
||||
|
||||
for v in approx_vaes:
|
||||
if v.startswith("taesd_decoder."):
|
||||
sd1_taesd_dec = True
|
||||
elif v.startswith("taesd_encoder."):
|
||||
sd1_taesd_enc = True
|
||||
elif v.startswith("taesdxl_decoder."):
|
||||
sdxl_taesd_dec = True
|
||||
elif v.startswith("taesdxl_encoder."):
|
||||
sdxl_taesd_enc = True
|
||||
if sd1_taesd_dec and sd1_taesd_enc:
|
||||
vaes.append("taesd")
|
||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||
vaes.append("taesdxl")
|
||||
return vaes
|
||||
|
||||
@staticmethod
|
||||
def load_taesd(name):
|
||||
sd = {}
|
||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||
|
||||
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
|
||||
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
|
||||
|
||||
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
|
||||
for k in enc:
|
||||
sd["taesd_encoder.{}".format(k)] = enc[k]
|
||||
|
||||
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
|
||||
for k in dec:
|
||||
sd["taesd_decoder.{}".format(k)] = dec[k]
|
||||
|
||||
if name == "taesd":
|
||||
sd["vae_scale"] = torch.tensor(0.18215)
|
||||
elif name == "taesdxl":
|
||||
sd["vae_scale"] = torch.tensor(0.13025)
|
||||
return sd
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (s.vae_list(), )}}
|
||||
RETURN_TYPES = ("VAE",)
|
||||
FUNCTION = "load_vae"
|
||||
|
||||
@ -583,8 +642,11 @@ class VAELoader:
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
if vae_name in ["taesd", "taesdxl"]:
|
||||
sd = self.load_taesd(vae_name)
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
return (vae,)
|
||||
|
||||
@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
|
||||
if prev_cnet in cnets:
|
||||
c_net = cnets[prev_cnet]
|
||||
else:
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
|
||||
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
|
||||
c_net.set_previous_controlnet(prev_cnet)
|
||||
cnets[prev_cnet] = c_net
|
||||
|
||||
@ -1275,6 +1337,7 @@ class SaveImage:
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
self.compress_level = 4
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1308,7 +1371,7 @@ class SaveImage:
|
||||
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
|
||||
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
|
||||
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
self.type = "temp"
|
||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||
self.compress_level = 1
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
"ConditioningZeroOut": ConditioningZeroOut,
|
||||
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
|
||||
"LoraLoaderModelOnly": LoraLoaderModelOnly,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -1759,7 +1824,7 @@ def load_custom_nodes():
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_import_times = []
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path)
|
||||
possible_modules = os.listdir(os.path.realpath(custom_node_path))
|
||||
if "__pycache__" in possible_modules:
|
||||
possible_modules.remove("__pycache__")
|
||||
|
||||
@ -1799,6 +1864,9 @@ def init_custom_nodes():
|
||||
"nodes_custom_sampler.py",
|
||||
"nodes_hypertile.py",
|
||||
"nodes_model_advanced.py",
|
||||
"nodes_model_downscale.py",
|
||||
"nodes_images.py",
|
||||
"nodes_video_model.py",
|
||||
]
|
||||
|
||||
for node_file in extras_files:
|
||||
|
@ -431,7 +431,10 @@ class PromptServer():
|
||||
|
||||
@routes.get("/history")
|
||||
async def get_history(request):
|
||||
return web.json_response(self.prompt_queue.get_history())
|
||||
max_items = request.rel_url.query.get("max_items", None)
|
||||
if max_items is not None:
|
||||
max_items = int(max_items)
|
||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
|
||||
|
||||
@routes.get("/history/{prompt_id}")
|
||||
async def get_history(request):
|
||||
@ -573,7 +576,7 @@ class PromptServer():
|
||||
bytesIO = BytesIO()
|
||||
header = struct.pack(">I", type_num)
|
||||
bytesIO.write(header)
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
|
@ -20,6 +20,7 @@ async function setup() {
|
||||
// Modify the response data to add some checkpoints
|
||||
const objectInfo = JSON.parse(data);
|
||||
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
|
||||
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
|
||||
|
||||
data = JSON.stringify(objectInfo, undefined, "\t");
|
||||
|
||||
|
196
tests-ui/tests/extensions.test.js
Normal file
196
tests-ui/tests/extensions.test.js
Normal file
@ -0,0 +1,196 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
const { start } = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
describe("extensions", () => {
|
||||
beforeEach(() => {
|
||||
lg.setup(global);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
lg.teardown(global);
|
||||
});
|
||||
|
||||
it("calls each extension hook", async () => {
|
||||
const mockExtension = {
|
||||
name: "TestExtension",
|
||||
init: jest.fn(),
|
||||
setup: jest.fn(),
|
||||
addCustomNodeDefs: jest.fn(),
|
||||
getCustomWidgets: jest.fn(),
|
||||
beforeRegisterNodeDef: jest.fn(),
|
||||
registerCustomNodes: jest.fn(),
|
||||
loadedGraphNode: jest.fn(),
|
||||
nodeCreated: jest.fn(),
|
||||
beforeConfigureGraph: jest.fn(),
|
||||
afterConfigureGraph: jest.fn(),
|
||||
};
|
||||
|
||||
const { app, ez, graph } = await start({
|
||||
async preSetup(app) {
|
||||
app.registerExtension(mockExtension);
|
||||
},
|
||||
});
|
||||
|
||||
// Basic initialisation hooks should be called once, with app
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.init).toHaveBeenCalledWith(app);
|
||||
|
||||
// Adding custom node defs should be passed the full list of nodes
|
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
|
||||
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
|
||||
expect(defs).toHaveProperty("KSampler");
|
||||
expect(defs).toHaveProperty("LoadImage");
|
||||
|
||||
// Get custom widgets is called once and should return new widget types
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
|
||||
|
||||
// Before register node def will be called once per node type
|
||||
const nodeNames = Object.keys(defs);
|
||||
const nodeCount = nodeNames.length;
|
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
||||
for (let i = 0; i < nodeCount; i++) {
|
||||
// It should be send the JS class and the original JSON definition
|
||||
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
|
||||
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
|
||||
|
||||
expect(nodeClass.name).toBe("ComfyNode");
|
||||
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
|
||||
expect(nodeDef.name).toBe(nodeNames[i]);
|
||||
expect(nodeDef).toHaveProperty("input");
|
||||
expect(nodeDef).toHaveProperty("output");
|
||||
}
|
||||
|
||||
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
|
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Before configure graph will be called here as the default graph is being loaded
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
|
||||
// it gets sent the graph data that is going to be loaded
|
||||
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
|
||||
|
||||
// A node created is fired for each node constructor that is called
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
|
||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
||||
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
||||
}
|
||||
|
||||
// Each node then calls loadedGraphNode to allow them to be updated
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
||||
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
|
||||
}
|
||||
|
||||
// After configure is then called once all the setup is done
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
|
||||
|
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.setup).toHaveBeenCalledWith(app);
|
||||
|
||||
// Ensure hooks are called in the correct order
|
||||
const callOrder = [
|
||||
"init",
|
||||
"addCustomNodeDefs",
|
||||
"getCustomWidgets",
|
||||
"beforeRegisterNodeDef",
|
||||
"registerCustomNodes",
|
||||
"beforeConfigureGraph",
|
||||
"nodeCreated",
|
||||
"loadedGraphNode",
|
||||
"afterConfigureGraph",
|
||||
"setup",
|
||||
];
|
||||
for (let i = 1; i < callOrder.length; i++) {
|
||||
const fn1 = mockExtension[callOrder[i - 1]];
|
||||
const fn2 = mockExtension[callOrder[i]];
|
||||
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
|
||||
}
|
||||
|
||||
graph.clear();
|
||||
|
||||
// Ensure adding a new node calls the correct callback
|
||||
ez.LoadImage();
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
||||
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
|
||||
|
||||
// Reload the graph to ensure correct hooks are fired
|
||||
await graph.reload();
|
||||
|
||||
// These hooks should not be fired again
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
|
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
|
||||
|
||||
// These should be called again
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("allows custom nodeDefs and widgets to be registered", async () => {
|
||||
const widgetMock = jest.fn((node, inputName, inputData, app) => {
|
||||
expect(node.constructor.comfyClass).toBe("TestNode");
|
||||
expect(inputName).toBe("test_input");
|
||||
expect(inputData[0]).toBe("CUSTOMWIDGET");
|
||||
expect(inputData[1]?.hello).toBe("world");
|
||||
expect(app).toStrictEqual(app);
|
||||
|
||||
return {
|
||||
widget: node.addWidget("button", inputName, "hello", () => {}),
|
||||
};
|
||||
});
|
||||
|
||||
// Register our extension that adds a custom node + widget type
|
||||
const mockExtension = {
|
||||
name: "TestExtension",
|
||||
addCustomNodeDefs: (nodeDefs) => {
|
||||
nodeDefs["TestNode"] = {
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
name: "TestNode",
|
||||
display_name: "TestNode",
|
||||
category: "Test",
|
||||
input: {
|
||||
required: {
|
||||
test_input: ["CUSTOMWIDGET", { hello: "world" }],
|
||||
},
|
||||
},
|
||||
};
|
||||
},
|
||||
getCustomWidgets: jest.fn(() => {
|
||||
return {
|
||||
CUSTOMWIDGET: widgetMock,
|
||||
};
|
||||
}),
|
||||
};
|
||||
|
||||
const { graph, ez } = await start({
|
||||
async preSetup(app) {
|
||||
app.registerExtension(mockExtension);
|
||||
},
|
||||
});
|
||||
|
||||
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
|
||||
|
||||
graph.clear();
|
||||
expect(widgetMock).toBeCalledTimes(0);
|
||||
const node = ez.TestNode();
|
||||
expect(widgetMock).toBeCalledTimes(1);
|
||||
|
||||
// Ensure our custom widget is created
|
||||
expect(node.inputs.length).toBe(0);
|
||||
expect(node.widgets.length).toBe(1);
|
||||
const w = node.widgets[0].widget;
|
||||
expect(w.name).toBe("test_input");
|
||||
expect(w.type).toBe("button");
|
||||
});
|
||||
});
|
818
tests-ui/tests/groupNode.test.js
Normal file
818
tests-ui/tests/groupNode.test.js
Normal file
@ -0,0 +1,818 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
|
||||
const { start, createDefaultWorkflow } = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
describe("group node", () => {
|
||||
beforeEach(() => {
|
||||
lg.setup(global);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
lg.teardown(global);
|
||||
});
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {*} app
|
||||
* @param {*} graph
|
||||
* @param {*} name
|
||||
* @param {*} nodes
|
||||
* @returns { Promise<InstanceType<import("../utils/ezgraph")["EzNode"]>> }
|
||||
*/
|
||||
async function convertToGroup(app, graph, name, nodes) {
|
||||
// Select the nodes we are converting
|
||||
for (const n of nodes) {
|
||||
n.select(true);
|
||||
}
|
||||
|
||||
expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual(
|
||||
nodes.map((n) => n.id + "").sort((a, b) => +a - +b)
|
||||
);
|
||||
|
||||
global.prompt = jest.fn().mockImplementation(() => name);
|
||||
const groupNode = await nodes[0].menu["Convert to Group Node"].call(false);
|
||||
|
||||
// Check group name was requested
|
||||
expect(window.prompt).toHaveBeenCalled();
|
||||
|
||||
// Ensure old nodes are removed
|
||||
for (const n of nodes) {
|
||||
expect(n.isRemoved).toBeTruthy();
|
||||
}
|
||||
|
||||
expect(groupNode.type).toEqual("workflow/" + name);
|
||||
|
||||
return graph.find(groupNode);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { Record<string, string | number> | number[] } idMap
|
||||
* @param { Record<string, Record<string, unknown>> } valueMap
|
||||
*/
|
||||
function getOutput(idMap = {}, valueMap = {}) {
|
||||
if (idMap instanceof Array) {
|
||||
idMap = idMap.reduce((p, n) => {
|
||||
p[n] = n + "";
|
||||
return p;
|
||||
}, {});
|
||||
}
|
||||
const expected = {
|
||||
1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
|
||||
2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
|
||||
3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" },
|
||||
4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" },
|
||||
5: {
|
||||
inputs: {
|
||||
seed: 0,
|
||||
steps: 20,
|
||||
cfg: 8,
|
||||
sampler_name: "euler",
|
||||
scheduler: "normal",
|
||||
denoise: 1,
|
||||
model: ["1", 0],
|
||||
positive: ["2", 0],
|
||||
negative: ["3", 0],
|
||||
latent_image: ["4", 0],
|
||||
...valueMap?.[5],
|
||||
},
|
||||
class_type: "KSampler",
|
||||
},
|
||||
6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" },
|
||||
7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" },
|
||||
};
|
||||
|
||||
// Map old IDs to new at the top level
|
||||
const mapped = {};
|
||||
for (const oldId in idMap) {
|
||||
mapped[idMap[oldId]] = expected[oldId];
|
||||
delete expected[oldId];
|
||||
}
|
||||
Object.assign(mapped, expected);
|
||||
|
||||
// Map old IDs to new inside links
|
||||
for (const k in mapped) {
|
||||
for (const input in mapped[k].inputs) {
|
||||
const v = mapped[k].inputs[input];
|
||||
if (v instanceof Array) {
|
||||
if (v[0] in idMap) {
|
||||
v[0] = idMap[v[0]] + "";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mapped;
|
||||
}
|
||||
|
||||
test("can be created from selected nodes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
|
||||
|
||||
// Ensure links are now to the group node
|
||||
expect(group.inputs).toHaveLength(2);
|
||||
expect(group.outputs).toHaveLength(3);
|
||||
|
||||
expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]);
|
||||
expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]);
|
||||
|
||||
// ckpt clip to both clip inputs on the group
|
||||
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
|
||||
[group.id, 0],
|
||||
[group.id, 1],
|
||||
]);
|
||||
|
||||
// group conditioning to sampler
|
||||
expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
|
||||
[nodes.sampler.id, 1],
|
||||
]);
|
||||
// group conditioning 2 to sampler
|
||||
expect(
|
||||
group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
|
||||
).toEqual([[nodes.sampler.id, 2]]);
|
||||
// group latent to sampler
|
||||
expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
|
||||
[nodes.sampler.id, 3],
|
||||
]);
|
||||
});
|
||||
|
||||
test("maintains all output links on conversion", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
const save2 = ez.SaveImage(...nodes.decode.outputs);
|
||||
const save3 = ez.SaveImage(...nodes.decode.outputs);
|
||||
// Ensure an output with multiple links maintains them on convert to group
|
||||
const group = await convertToGroup(app, graph, "test", [nodes.sampler, nodes.decode]);
|
||||
expect(group.outputs[0].connections.length).toBe(3);
|
||||
expect(group.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
|
||||
expect(group.outputs[0].connections[1].targetNode.id).toBe(save2.id);
|
||||
expect(group.outputs[0].connections[2].targetNode.id).toBe(save3.id);
|
||||
|
||||
// and they're still linked when converting back to nodes
|
||||
const newNodes = group.menu["Convert to nodes"].call();
|
||||
const decode = graph.find(newNodes.find((n) => n.type === "VAEDecode"));
|
||||
expect(decode.outputs[0].connections.length).toBe(3);
|
||||
expect(decode.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
|
||||
expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id);
|
||||
expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id);
|
||||
});
|
||||
test("can be be converted back to nodes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler];
|
||||
const group = await convertToGroup(app, graph, "test", toConvert);
|
||||
|
||||
// Edit some values to ensure they are set back onto the converted nodes
|
||||
expect(group.widgets["text"].value).toBe("positive");
|
||||
group.widgets["text"].value = "pos";
|
||||
expect(group.widgets["CLIPTextEncode text"].value).toBe("negative");
|
||||
group.widgets["CLIPTextEncode text"].value = "neg";
|
||||
expect(group.widgets["width"].value).toBe(512);
|
||||
group.widgets["width"].value = 1024;
|
||||
expect(group.widgets["sampler_name"].value).toBe("euler");
|
||||
group.widgets["sampler_name"].value = "ddim";
|
||||
expect(group.widgets["control_after_generate"].value).toBe("randomize");
|
||||
group.widgets["control_after_generate"].value = "fixed";
|
||||
|
||||
/** @type { Array<any> } */
|
||||
group.menu["Convert to nodes"].call();
|
||||
|
||||
// ensure widget values are set
|
||||
const pos = graph.find(nodes.pos.id);
|
||||
expect(pos.node.type).toBe("CLIPTextEncode");
|
||||
expect(pos.widgets["text"].value).toBe("pos");
|
||||
const neg = graph.find(nodes.neg.id);
|
||||
expect(neg.node.type).toBe("CLIPTextEncode");
|
||||
expect(neg.widgets["text"].value).toBe("neg");
|
||||
const empty = graph.find(nodes.empty.id);
|
||||
expect(empty.node.type).toBe("EmptyLatentImage");
|
||||
expect(empty.widgets["width"].value).toBe(1024);
|
||||
const sampler = graph.find(nodes.sampler.id);
|
||||
expect(sampler.node.type).toBe("KSampler");
|
||||
expect(sampler.widgets["sampler_name"].value).toBe("ddim");
|
||||
expect(sampler.widgets["control_after_generate"].value).toBe("fixed");
|
||||
|
||||
// validate links
|
||||
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
|
||||
[pos.id, 0],
|
||||
[neg.id, 0],
|
||||
]);
|
||||
|
||||
expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
|
||||
[nodes.sampler.id, 1],
|
||||
]);
|
||||
|
||||
expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
|
||||
[nodes.sampler.id, 2],
|
||||
]);
|
||||
|
||||
expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
|
||||
[nodes.sampler.id, 3],
|
||||
]);
|
||||
});
|
||||
test("it can embed reroutes as inputs", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
// Add and connect a reroute to the clip text encodes
|
||||
const reroute = ez.Reroute();
|
||||
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
|
||||
reroute.outputs[0].connectTo(nodes.pos.inputs[0]);
|
||||
reroute.outputs[0].connectTo(nodes.neg.inputs[0]);
|
||||
|
||||
// Convert to group and ensure we only have 1 input of the correct type
|
||||
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]);
|
||||
expect(group.inputs).toHaveLength(1);
|
||||
expect(group.inputs[0].input.type).toEqual("CLIP");
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput());
|
||||
});
|
||||
test("it can embed reroutes as outputs", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
// Add a reroute with no output so we output IMAGE even though its used internally
|
||||
const reroute = ez.Reroute();
|
||||
nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]);
|
||||
|
||||
// Convert to group and ensure there is an IMAGE output
|
||||
const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
|
||||
expect(group.outputs).toHaveLength(1);
|
||||
expect(group.outputs[0].output.type).toEqual("IMAGE");
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id]));
|
||||
});
|
||||
test("it can embed reroutes as pipes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
// Use reroutes as a pipe
|
||||
const rerouteModel = ez.Reroute();
|
||||
const rerouteClip = ez.Reroute();
|
||||
const rerouteVae = ez.Reroute();
|
||||
nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]);
|
||||
nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]);
|
||||
nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]);
|
||||
|
||||
expect(group.outputs).toHaveLength(3);
|
||||
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
|
||||
|
||||
expect(group.outputs).toHaveLength(3);
|
||||
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
|
||||
|
||||
group.outputs[0].connectTo(nodes.sampler.inputs.model);
|
||||
group.outputs[1].connectTo(nodes.pos.inputs.clip);
|
||||
group.outputs[1].connectTo(nodes.neg.inputs.clip);
|
||||
});
|
||||
test("can handle reroutes used internally", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
let reroutes = [];
|
||||
let prevNode = nodes.ckpt;
|
||||
for(let i = 0; i < 5; i++) {
|
||||
const reroute = ez.Reroute();
|
||||
prevNode.outputs[0].connectTo(reroute.inputs[0]);
|
||||
prevNode = reroute;
|
||||
reroutes.push(reroute);
|
||||
}
|
||||
prevNode.outputs[0].connectTo(nodes.sampler.inputs.model);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput());
|
||||
|
||||
group.menu["Convert to nodes"].call();
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput());
|
||||
});
|
||||
test("creates with widget values from inner nodes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt";
|
||||
nodes.pos.widgets.text.value = "hello";
|
||||
nodes.neg.widgets.text.value = "world";
|
||||
nodes.empty.widgets.width.value = 256;
|
||||
nodes.empty.widgets.height.value = 1024;
|
||||
nodes.sampler.widgets.seed.value = 1;
|
||||
nodes.sampler.widgets.control_after_generate.value = "increment";
|
||||
nodes.sampler.widgets.steps.value = 8;
|
||||
nodes.sampler.widgets.cfg.value = 4.5;
|
||||
nodes.sampler.widgets.sampler_name.value = "uni_pc";
|
||||
nodes.sampler.widgets.scheduler.value = "karras";
|
||||
nodes.sampler.widgets.denoise.value = 0.9;
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [
|
||||
nodes.ckpt,
|
||||
nodes.pos,
|
||||
nodes.neg,
|
||||
nodes.empty,
|
||||
nodes.sampler,
|
||||
]);
|
||||
|
||||
expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt");
|
||||
expect(group.widgets["text"].value).toEqual("hello");
|
||||
expect(group.widgets["CLIPTextEncode text"].value).toEqual("world");
|
||||
expect(group.widgets["width"].value).toEqual(256);
|
||||
expect(group.widgets["height"].value).toEqual(1024);
|
||||
expect(group.widgets["seed"].value).toEqual(1);
|
||||
expect(group.widgets["control_after_generate"].value).toEqual("increment");
|
||||
expect(group.widgets["steps"].value).toEqual(8);
|
||||
expect(group.widgets["cfg"].value).toEqual(4.5);
|
||||
expect(group.widgets["sampler_name"].value).toEqual("uni_pc");
|
||||
expect(group.widgets["scheduler"].value).toEqual("karras");
|
||||
expect(group.widgets["denoise"].value).toEqual(0.9);
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual(
|
||||
getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
|
||||
[nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
|
||||
[nodes.pos.id]: { text: "hello" },
|
||||
[nodes.neg.id]: { text: "world" },
|
||||
[nodes.empty.id]: { width: 256, height: 1024 },
|
||||
[nodes.sampler.id]: {
|
||||
seed: 1,
|
||||
steps: 8,
|
||||
cfg: 4.5,
|
||||
sampler_name: "uni_pc",
|
||||
scheduler: "karras",
|
||||
denoise: 0.9,
|
||||
},
|
||||
})
|
||||
);
|
||||
});
|
||||
test("group inputs can be reroutes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
|
||||
|
||||
const reroute = ez.Reroute();
|
||||
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
|
||||
|
||||
reroute.outputs[0].connectTo(group.inputs[0]);
|
||||
reroute.outputs[0].connectTo(group.inputs[1]);
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
|
||||
});
|
||||
test("group outputs can be reroutes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
|
||||
|
||||
const reroute1 = ez.Reroute();
|
||||
const reroute2 = ez.Reroute();
|
||||
group.outputs[0].connectTo(reroute1.inputs[0]);
|
||||
group.outputs[1].connectTo(reroute2.inputs[0]);
|
||||
|
||||
reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
|
||||
reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
|
||||
});
|
||||
test("groups can connect to each other", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
|
||||
const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
|
||||
|
||||
group1.outputs[0].connectTo(group2.inputs["positive"]);
|
||||
group1.outputs[1].connectTo(group2.inputs["negative"]);
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual(
|
||||
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
|
||||
);
|
||||
});
|
||||
test("displays generated image on group node", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
let group = await convertToGroup(app, graph, "test", [
|
||||
nodes.pos,
|
||||
nodes.neg,
|
||||
nodes.empty,
|
||||
nodes.sampler,
|
||||
nodes.decode,
|
||||
nodes.save,
|
||||
]);
|
||||
|
||||
const { api } = require("../../web/scripts/api");
|
||||
|
||||
api.dispatchEvent(new CustomEvent("execution_start", {}));
|
||||
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
|
||||
// Event should be forwarded to group node id
|
||||
expect(+app.runningNodeId).toEqual(group.id);
|
||||
expect(group.node["imgs"]).toBeFalsy();
|
||||
api.dispatchEvent(
|
||||
new CustomEvent("executed", {
|
||||
detail: {
|
||||
node: `${nodes.save.id}`,
|
||||
output: {
|
||||
images: [
|
||||
{
|
||||
filename: "test.png",
|
||||
type: "output",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
// Trigger paint
|
||||
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
|
||||
|
||||
expect(group.node["images"]).toEqual([
|
||||
{
|
||||
filename: "test.png",
|
||||
type: "output",
|
||||
},
|
||||
]);
|
||||
|
||||
// Reload
|
||||
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
|
||||
await app.loadGraphData(JSON.parse(workflow));
|
||||
group = graph.find(group);
|
||||
|
||||
// Trigger inner nodes to get created
|
||||
group.node["getInnerNodes"]();
|
||||
|
||||
// Check it works for internal node ids
|
||||
api.dispatchEvent(new CustomEvent("execution_start", {}));
|
||||
api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` }));
|
||||
// Event should be forwarded to group node id
|
||||
expect(+app.runningNodeId).toEqual(group.id);
|
||||
expect(group.node["imgs"]).toBeFalsy();
|
||||
api.dispatchEvent(
|
||||
new CustomEvent("executed", {
|
||||
detail: {
|
||||
node: `${group.id}:5`,
|
||||
output: {
|
||||
images: [
|
||||
{
|
||||
filename: "test2.png",
|
||||
type: "output",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
// Trigger paint
|
||||
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
|
||||
|
||||
expect(group.node["images"]).toEqual([
|
||||
{
|
||||
filename: "test2.png",
|
||||
type: "output",
|
||||
},
|
||||
]);
|
||||
});
|
||||
test("allows widgets to be converted to inputs", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
|
||||
group.widgets[0].convertToInput();
|
||||
|
||||
const primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(group.inputs["text"]);
|
||||
primitive.widgets[0].value = "hello";
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual(
|
||||
getOutput([nodes.pos.id, nodes.neg.id], {
|
||||
[nodes.pos.id]: { text: "hello" },
|
||||
})
|
||||
);
|
||||
});
|
||||
test("can be copied", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
const group1 = await convertToGroup(app, graph, "test", [
|
||||
nodes.pos,
|
||||
nodes.neg,
|
||||
nodes.empty,
|
||||
nodes.sampler,
|
||||
nodes.decode,
|
||||
nodes.save,
|
||||
]);
|
||||
|
||||
group1.widgets["text"].value = "hello";
|
||||
group1.widgets["width"].value = 256;
|
||||
group1.widgets["seed"].value = 1;
|
||||
|
||||
// Clone the node
|
||||
group1.menu.Clone.call();
|
||||
expect(app.graph._nodes).toHaveLength(3);
|
||||
const group2 = graph.find(app.graph._nodes[2]);
|
||||
expect(group2.node.type).toEqual("workflow/test");
|
||||
expect(group2.id).not.toEqual(group1.id);
|
||||
|
||||
// Reconnect ckpt
|
||||
nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]);
|
||||
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]);
|
||||
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
|
||||
nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]);
|
||||
|
||||
group2.widgets["text"].value = "world";
|
||||
group2.widgets["width"].value = 1024;
|
||||
group2.widgets["seed"].value = 100;
|
||||
|
||||
let i = 0;
|
||||
expect((await graph.toPrompt()).output).toEqual({
|
||||
...getOutput([nodes.empty.id, nodes.pos.id, nodes.neg.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], {
|
||||
[nodes.empty.id]: { width: 256 },
|
||||
[nodes.pos.id]: { text: "hello" },
|
||||
[nodes.sampler.id]: { seed: 1 },
|
||||
}),
|
||||
...getOutput(
|
||||
{
|
||||
[nodes.empty.id]: `${group2.id}:${i++}`,
|
||||
[nodes.pos.id]: `${group2.id}:${i++}`,
|
||||
[nodes.neg.id]: `${group2.id}:${i++}`,
|
||||
[nodes.sampler.id]: `${group2.id}:${i++}`,
|
||||
[nodes.decode.id]: `${group2.id}:${i++}`,
|
||||
[nodes.save.id]: `${group2.id}:${i++}`,
|
||||
},
|
||||
{
|
||||
[nodes.empty.id]: { width: 1024 },
|
||||
[nodes.pos.id]: { text: "world" },
|
||||
[nodes.sampler.id]: { seed: 100 },
|
||||
}
|
||||
),
|
||||
});
|
||||
|
||||
graph.arrange();
|
||||
});
|
||||
test("is embedded in workflow", async () => {
|
||||
let { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
|
||||
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
|
||||
|
||||
// Clear the environment
|
||||
({ ez, graph, app } = await start({
|
||||
resetEnv: true,
|
||||
}));
|
||||
// Ensure the node isnt registered
|
||||
expect(() => ez["workflow/test"]).toThrow();
|
||||
|
||||
// Reload the workflow
|
||||
await app.loadGraphData(JSON.parse(workflow));
|
||||
|
||||
// Ensure the node is found
|
||||
group = graph.find(group);
|
||||
|
||||
// Generate prompt and ensure it is as expected
|
||||
expect((await graph.toPrompt()).output).toEqual(
|
||||
getOutput({
|
||||
[nodes.pos.id]: `${group.id}:0`,
|
||||
[nodes.neg.id]: `${group.id}:1`,
|
||||
})
|
||||
);
|
||||
});
|
||||
test("shows missing node error on missing internal node when loading graph data", async () => {
|
||||
const { graph } = await start();
|
||||
|
||||
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
|
||||
await graph.app.loadGraphData({
|
||||
last_node_id: 3,
|
||||
last_link_id: 1,
|
||||
nodes: [
|
||||
{
|
||||
id: 3,
|
||||
type: "workflow/testerror",
|
||||
},
|
||||
],
|
||||
links: [],
|
||||
groups: [],
|
||||
config: {},
|
||||
extra: {
|
||||
groupNodes: {
|
||||
testerror: {
|
||||
nodes: [
|
||||
{
|
||||
type: "NotKSampler",
|
||||
},
|
||||
{
|
||||
type: "NotVAEDecode",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(dialogShow).toBeCalledTimes(1);
|
||||
const call = dialogShow.mock.calls[0][0].innerHTML;
|
||||
expect(call).toContain("the following node types were not found");
|
||||
expect(call).toContain("NotKSampler");
|
||||
expect(call).toContain("NotVAEDecode");
|
||||
expect(call).toContain("workflow/testerror");
|
||||
});
|
||||
test("maintains widget inputs on conversion back to nodes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
let pos = ez.CLIPTextEncode({ text: "positive" });
|
||||
pos.node.title = "Positive";
|
||||
let neg = ez.CLIPTextEncode({ text: "negative" });
|
||||
neg.node.title = "Negative";
|
||||
pos.widgets.text.convertToInput();
|
||||
neg.widgets.text.convertToInput();
|
||||
|
||||
let primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(pos.inputs.text);
|
||||
primitive.outputs[0].connectTo(neg.inputs.text);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
|
||||
// This will use a primitive widget named 'value'
|
||||
expect(group.widgets.length).toBe(1);
|
||||
expect(group.widgets["value"].value).toBe("positive");
|
||||
|
||||
const newNodes = group.menu["Convert to nodes"].call();
|
||||
pos = graph.find(newNodes.find((n) => n.title === "Positive"));
|
||||
neg = graph.find(newNodes.find((n) => n.title === "Negative"));
|
||||
primitive = graph.find(newNodes.find((n) => n.type === "PrimitiveNode"));
|
||||
|
||||
expect(pos.inputs).toHaveLength(2);
|
||||
expect(neg.inputs).toHaveLength(2);
|
||||
expect(primitive.outputs[0].connections).toHaveLength(2);
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual({
|
||||
1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
|
||||
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
|
||||
});
|
||||
});
|
||||
test("adds widgets in node execution order", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const scale = ez.LatentUpscale();
|
||||
const save = ez.SaveImage();
|
||||
const empty = ez.EmptyLatentImage();
|
||||
const decode = ez.VAEDecode();
|
||||
|
||||
scale.outputs.LATENT.connectTo(decode.inputs.samples);
|
||||
decode.outputs.IMAGE.connectTo(save.inputs.images);
|
||||
empty.outputs.LATENT.connectTo(scale.inputs.samples);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]);
|
||||
const widgets = group.widgets.map((w) => w.widget.name);
|
||||
expect(widgets).toStrictEqual([
|
||||
"width",
|
||||
"height",
|
||||
"batch_size",
|
||||
"upscale_method",
|
||||
"LatentUpscale width",
|
||||
"LatentUpscale height",
|
||||
"crop",
|
||||
"filename_prefix",
|
||||
]);
|
||||
});
|
||||
test("adds output for external links when converting to group", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const img = ez.EmptyLatentImage();
|
||||
let decode = ez.VAEDecode(...img.outputs);
|
||||
const preview1 = ez.PreviewImage(...decode.outputs);
|
||||
const preview2 = ez.PreviewImage(...decode.outputs);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [img, decode, preview1]);
|
||||
|
||||
// Ensure we have an output connected to the 2nd preview node
|
||||
expect(group.outputs.length).toBe(1);
|
||||
expect(group.outputs[0].connections.length).toBe(1);
|
||||
expect(group.outputs[0].connections[0].targetNode.id).toBe(preview2.id);
|
||||
|
||||
// Convert back and ensure bothe previews are still connected
|
||||
group.menu["Convert to nodes"].call();
|
||||
decode = graph.find(decode);
|
||||
expect(decode.outputs[0].connections.length).toBe(2);
|
||||
expect(decode.outputs[0].connections[0].targetNode.id).toBe(preview1.id);
|
||||
expect(decode.outputs[0].connections[1].targetNode.id).toBe(preview2.id);
|
||||
});
|
||||
test("adds output for external links when converting to group when nodes are not in execution order", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const sampler = ez.KSampler();
|
||||
const ckpt = ez.CheckpointLoaderSimple();
|
||||
const empty = ez.EmptyLatentImage();
|
||||
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
|
||||
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
|
||||
const decode1 = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
|
||||
const save = ez.SaveImage(decode1.outputs.IMAGE);
|
||||
ckpt.outputs.MODEL.connectTo(sampler.inputs.model);
|
||||
pos.outputs.CONDITIONING.connectTo(sampler.inputs.positive);
|
||||
neg.outputs.CONDITIONING.connectTo(sampler.inputs.negative);
|
||||
empty.outputs.LATENT.connectTo(sampler.inputs.latent_image);
|
||||
|
||||
const encode = ez.VAEEncode(decode1.outputs.IMAGE);
|
||||
const vae = ez.VAELoader();
|
||||
const decode2 = ez.VAEDecode(encode.outputs.LATENT, vae.outputs.VAE);
|
||||
const preview = ez.PreviewImage(decode2.outputs.IMAGE);
|
||||
vae.outputs.VAE.connectTo(encode.inputs.vae);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]);
|
||||
|
||||
expect(group.outputs.length).toBe(3);
|
||||
expect(group.outputs[0].output.name).toBe("VAE");
|
||||
expect(group.outputs[0].output.type).toBe("VAE");
|
||||
expect(group.outputs[1].output.name).toBe("IMAGE");
|
||||
expect(group.outputs[1].output.type).toBe("IMAGE");
|
||||
expect(group.outputs[2].output.name).toBe("LATENT");
|
||||
expect(group.outputs[2].output.type).toBe("LATENT");
|
||||
|
||||
expect(group.outputs[0].connections.length).toBe(1);
|
||||
expect(group.outputs[0].connections[0].targetNode.id).toBe(decode2.id);
|
||||
expect(group.outputs[0].connections[0].targetInput.index).toBe(1);
|
||||
|
||||
expect(group.outputs[1].connections.length).toBe(1);
|
||||
expect(group.outputs[1].connections[0].targetNode.id).toBe(save.id);
|
||||
expect(group.outputs[1].connections[0].targetInput.index).toBe(0);
|
||||
|
||||
expect(group.outputs[2].connections.length).toBe(1);
|
||||
expect(group.outputs[2].connections[0].targetNode.id).toBe(decode2.id);
|
||||
expect(group.outputs[2].connections[0].targetInput.index).toBe(0);
|
||||
|
||||
expect((await graph.toPrompt()).output).toEqual({
|
||||
...getOutput({ 1: ckpt.id, 2: pos.id, 3: neg.id, 4: empty.id, 5: sampler.id, 6: decode1.id, 7: save.id }),
|
||||
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: vae.node.type },
|
||||
[encode.id]: { inputs: { pixels: ["6", 0], vae: [vae.id + "", 0] }, class_type: encode.node.type },
|
||||
[decode2.id]: { inputs: { samples: [encode.id + "", 0], vae: [vae.id + "", 0] }, class_type: decode2.node.type },
|
||||
[preview.id]: { inputs: { images: [decode2.id + "", 0] }, class_type: preview.node.type },
|
||||
});
|
||||
});
|
||||
test("works with IMAGEUPLOAD widget", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const img = ez.LoadImage();
|
||||
const preview1 = ez.PreviewImage(img.outputs[0]);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [img, preview1]);
|
||||
const widget = group.widgets["upload"];
|
||||
expect(widget).toBeTruthy();
|
||||
expect(widget.widget.type).toBe("button");
|
||||
});
|
||||
test("internal primitive populates widgets for all linked inputs", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const img = ez.LoadImage();
|
||||
const scale1 = ez.ImageScale(img.outputs[0]);
|
||||
const scale2 = ez.ImageScale(img.outputs[0]);
|
||||
ez.PreviewImage(scale1.outputs[0]);
|
||||
ez.PreviewImage(scale2.outputs[0]);
|
||||
|
||||
scale1.widgets.width.convertToInput();
|
||||
scale2.widgets.height.convertToInput();
|
||||
|
||||
const primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(scale1.inputs.width);
|
||||
primitive.outputs[0].connectTo(scale2.inputs.height);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [img, primitive, scale1, scale2]);
|
||||
group.widgets.value.value = 100;
|
||||
expect((await graph.toPrompt()).output).toEqual({
|
||||
1: {
|
||||
inputs: { image: img.widgets.image.value, upload: "image" },
|
||||
class_type: "LoadImage",
|
||||
},
|
||||
2: {
|
||||
inputs: { upscale_method: "nearest-exact", width: 100, height: 512, crop: "disabled", image: ["1", 0] },
|
||||
class_type: "ImageScale",
|
||||
},
|
||||
3: {
|
||||
inputs: { upscale_method: "nearest-exact", width: 512, height: 100, crop: "disabled", image: ["1", 0] },
|
||||
class_type: "ImageScale",
|
||||
},
|
||||
4: { inputs: { images: ["2", 0] }, class_type: "PreviewImage" },
|
||||
5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" },
|
||||
});
|
||||
});
|
||||
test("primitive control widgets values are copied on convert", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const sampler = ez.KSampler();
|
||||
sampler.widgets.seed.convertToInput();
|
||||
sampler.widgets.sampler_name.convertToInput();
|
||||
|
||||
let p1 = ez.PrimitiveNode();
|
||||
let p2 = ez.PrimitiveNode();
|
||||
p1.outputs[0].connectTo(sampler.inputs.seed);
|
||||
p2.outputs[0].connectTo(sampler.inputs.sampler_name);
|
||||
|
||||
p1.widgets.control_after_generate.value = "increment";
|
||||
p2.widgets.control_after_generate.value = "decrement";
|
||||
p2.widgets.control_filter_list.value = "/.*/";
|
||||
|
||||
p2.node.title = "p2";
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]);
|
||||
expect(group.widgets.control_after_generate.value).toBe("increment");
|
||||
expect(group.widgets["p2 control_after_generate"].value).toBe("decrement");
|
||||
expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/");
|
||||
|
||||
group.widgets.control_after_generate.value = "fixed";
|
||||
group.widgets["p2 control_after_generate"].value = "randomize";
|
||||
group.widgets["p2 control_filter_list"].value = "/.+/";
|
||||
|
||||
group.menu["Convert to nodes"].call();
|
||||
p1 = graph.find(p1);
|
||||
p2 = graph.find(p2);
|
||||
|
||||
expect(p1.widgets.control_after_generate.value).toBe("fixed");
|
||||
expect(p2.widgets.control_after_generate.value).toBe("randomize");
|
||||
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
|
||||
});
|
||||
});
|
@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
|
||||
* @param { InstanceType<Ez["EzGraph"]> } graph
|
||||
* @param { InstanceType<Ez["EzInput"]> } input
|
||||
* @param { string } widgetType
|
||||
* @param { boolean } hasControlWidget
|
||||
* @param { number } controlWidgetCount
|
||||
* @returns
|
||||
*/
|
||||
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) {
|
||||
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
|
||||
// Connect to primitive and ensure its still connected after
|
||||
let primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(input);
|
||||
@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
|
||||
expect(valueWidget.widget.type).toBe(widgetType);
|
||||
|
||||
// Check if control_after_generate should be added
|
||||
if (hasControlWidget) {
|
||||
if (controlWidgetCount) {
|
||||
const controlWidget = primitive.widgets.control_after_generate;
|
||||
expect(controlWidget.widget.type).toBe("combo");
|
||||
if(widgetType === "combo") {
|
||||
const filterWidget = primitive.widgets.control_filter_list;
|
||||
expect(filterWidget.widget.type).toBe("string");
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we dont have other widgets
|
||||
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget);
|
||||
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
|
||||
});
|
||||
|
||||
return primitive;
|
||||
@ -55,8 +59,8 @@ describe("widget inputs", () => {
|
||||
});
|
||||
|
||||
[
|
||||
{ name: "int", type: "INT", widget: "number", control: true },
|
||||
{ name: "float", type: "FLOAT", widget: "number", control: true },
|
||||
{ name: "int", type: "INT", widget: "number", control: 1 },
|
||||
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
|
||||
{ name: "text", type: "STRING" },
|
||||
{
|
||||
name: "customtext",
|
||||
@ -64,7 +68,7 @@ describe("widget inputs", () => {
|
||||
opt: { multiline: true },
|
||||
},
|
||||
{ name: "toggle", type: "BOOLEAN" },
|
||||
{ name: "combo", type: ["a", "b", "c"], control: true },
|
||||
{ name: "combo", type: ["a", "b", "c"], control: 2 },
|
||||
].forEach((c) => {
|
||||
test(`widget conversion + primitive works on ${c.name}`, async () => {
|
||||
const { ez, graph } = await start({
|
||||
@ -106,7 +110,7 @@ describe("widget inputs", () => {
|
||||
n.widgets.ckpt_name.convertToInput();
|
||||
expect(n.inputs.length).toEqual(inputCount + 1);
|
||||
|
||||
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true);
|
||||
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
|
||||
|
||||
// Disconnect & reconnect
|
||||
primitive.outputs[0].connections[0].disconnect();
|
||||
@ -198,8 +202,8 @@ describe("widget inputs", () => {
|
||||
});
|
||||
|
||||
expect(dialogShow).toBeCalledTimes(1);
|
||||
expect(dialogShow.mock.calls[0][0]).toContain("the following node types were not found");
|
||||
expect(dialogShow.mock.calls[0][0]).toContain("TestNode");
|
||||
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
|
||||
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
|
||||
});
|
||||
|
||||
test("defaultInput widgets can be converted back to inputs", async () => {
|
||||
@ -226,7 +230,7 @@ describe("widget inputs", () => {
|
||||
// Reload and ensure it still only has 1 converted widget
|
||||
if (!assertNotNullOrUndefined(input)) return;
|
||||
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", true);
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
||||
n = graph.find(n);
|
||||
expect(n.widgets).toHaveLength(1);
|
||||
w = n.widgets.example;
|
||||
@ -258,7 +262,7 @@ describe("widget inputs", () => {
|
||||
|
||||
// Reload and ensure it still only has 1 converted widget
|
||||
if (assertNotNullOrUndefined(input)) {
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", true);
|
||||
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
|
||||
n = graph.find(n);
|
||||
expect(n.widgets).toHaveLength(1);
|
||||
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
|
||||
@ -316,4 +320,76 @@ describe("widget inputs", () => {
|
||||
n1.outputs[0].connectTo(n2.inputs[0]);
|
||||
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
|
||||
});
|
||||
|
||||
test("combo primitive can filter list when control_after_generate called", async () => {
|
||||
const { ez } = await start({
|
||||
mockNodeDefs: {
|
||||
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
|
||||
},
|
||||
});
|
||||
|
||||
const n1 = ez.TestNode1();
|
||||
n1.widgets.example.convertToInput();
|
||||
const p = ez.PrimitiveNode()
|
||||
p.outputs[0].connectTo(n1.inputs[0]);
|
||||
|
||||
const value = p.widgets.value;
|
||||
const control = p.widgets.control_after_generate.widget;
|
||||
const filter = p.widgets.control_filter_list;
|
||||
|
||||
expect(p.widgets.length).toBe(3);
|
||||
control.value = "increment";
|
||||
expect(value.value).toBe("A");
|
||||
|
||||
// Manually trigger after queue when set to increment
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
|
||||
// Filter to items containing D
|
||||
filter.value = "D";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("D");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("DD");
|
||||
|
||||
// Check decrement
|
||||
value.value = "BBB";
|
||||
control.value = "decrement";
|
||||
filter.value = "B";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("BB");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
|
||||
// Check regex works
|
||||
value.value = "BBB";
|
||||
filter.value = "/[AB]|^C$/";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("AAA");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("BB");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("AA");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("C");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("A");
|
||||
|
||||
// Check random
|
||||
control.value = "randomize";
|
||||
filter.value = "/D/";
|
||||
for(let i = 0; i < 100; i++) {
|
||||
control["afterQueued"]();
|
||||
expect(value.value === "D" || value.value === "DD").toBeTruthy();
|
||||
}
|
||||
|
||||
// Ensure it doesnt apply when fixed
|
||||
control.value = "fixed";
|
||||
value.value = "B";
|
||||
filter.value = "C";
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
});
|
||||
});
|
||||
|
@ -150,7 +150,7 @@ export class EzNodeMenuItem {
|
||||
if (selectNode) {
|
||||
this.node.select();
|
||||
}
|
||||
this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
|
||||
return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,8 +240,12 @@ export class EzNode {
|
||||
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
|
||||
}
|
||||
|
||||
select() {
|
||||
this.app.canvas.selectNode(this.node);
|
||||
get isRemoved() {
|
||||
return !this.app.graph.getNodeById(this.id);
|
||||
}
|
||||
|
||||
select(addToSelection = false) {
|
||||
this.app.canvas.selectNode(this.node, addToSelection);
|
||||
}
|
||||
|
||||
// /**
|
||||
@ -275,12 +279,17 @@ export class EzNode {
|
||||
if (!s) return p;
|
||||
|
||||
const name = s[nameProperty];
|
||||
const item = new ctor(this, i, s);
|
||||
// @ts-ignore
|
||||
if (!name || name in p) {
|
||||
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
|
||||
p.push(item);
|
||||
if (name) {
|
||||
// @ts-ignore
|
||||
if (name in p) {
|
||||
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
|
||||
}
|
||||
}
|
||||
// @ts-ignore
|
||||
p.push((p[name] = new ctor(this, i, s)));
|
||||
p[name] = item;
|
||||
return p;
|
||||
}, Object.assign([], { $: this }));
|
||||
}
|
||||
@ -348,6 +357,19 @@ export class EzGraph {
|
||||
}, 10);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns { Promise<{
|
||||
* workflow: {},
|
||||
* output: Record<string, {
|
||||
* class_name: string,
|
||||
* inputs: Record<string, [string, number] | unknown>
|
||||
* }>}> }
|
||||
*/
|
||||
toPrompt() {
|
||||
// @ts-ignore
|
||||
return this.app.graphToPrompt();
|
||||
}
|
||||
}
|
||||
|
||||
export const Ez = {
|
||||
@ -356,12 +378,12 @@ export const Ez = {
|
||||
* @example
|
||||
* const { ez, graph } = Ez.graph(app);
|
||||
* graph.clear();
|
||||
* const [model, clip, vae] = ez.CheckpointLoaderSimple();
|
||||
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" });
|
||||
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" });
|
||||
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage());
|
||||
* const [image] = ez.VAEDecode(latent, vae);
|
||||
* const saveNode = ez.SaveImage(image).node;
|
||||
* const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
|
||||
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
|
||||
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
|
||||
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
|
||||
* const [image] = ez.VAEDecode(latent, vae).outputs;
|
||||
* const saveNode = ez.SaveImage(image);
|
||||
* console.log(saveNode);
|
||||
* graph.arrange();
|
||||
* @param { app } app
|
||||
|
@ -1,21 +1,29 @@
|
||||
const { mockApi } = require("./setup");
|
||||
const { Ez } = require("./ezgraph");
|
||||
const lg = require("./litegraph");
|
||||
|
||||
/**
|
||||
*
|
||||
* @param { Parameters<mockApi>[0] } config
|
||||
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
|
||||
* @returns
|
||||
*/
|
||||
export async function start(config = undefined) {
|
||||
export async function start(config = {}) {
|
||||
if(config.resetEnv) {
|
||||
jest.resetModules();
|
||||
jest.resetAllMocks();
|
||||
lg.setup(global);
|
||||
}
|
||||
|
||||
mockApi(config);
|
||||
const { app } = require("../../web/scripts/app");
|
||||
config.preSetup?.(app);
|
||||
await app.setup();
|
||||
return Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]);
|
||||
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { ReturnType<Ez["graph"]>["graph"] } graph
|
||||
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
|
||||
* @param { ReturnType<Ez["graph"]>["graph"] } graph
|
||||
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
|
||||
*/
|
||||
export async function checkBeforeAndAfterReload(graph, cb) {
|
||||
await cb(false);
|
||||
@ -24,10 +32,10 @@ export async function checkBeforeAndAfterReload(graph, cb) {
|
||||
}
|
||||
|
||||
/**
|
||||
* @param { string } name
|
||||
* @param { Record<string, string | [string | string[], any]> } input
|
||||
* @param { string } name
|
||||
* @param { Record<string, string | [string | string[], any]> } input
|
||||
* @param { (string | string[])[] | Record<string, string | string[]> } output
|
||||
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
|
||||
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
|
||||
*/
|
||||
export function makeNodeDef(name, input, output = {}) {
|
||||
const nodeDef = {
|
||||
@ -37,19 +45,19 @@ export function makeNodeDef(name, input, output = {}) {
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
input: {
|
||||
required: {}
|
||||
required: {},
|
||||
},
|
||||
};
|
||||
for(const k in input) {
|
||||
for (const k in input) {
|
||||
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
|
||||
}
|
||||
if(output instanceof Array) {
|
||||
if (output instanceof Array) {
|
||||
output = output.reduce((p, c) => {
|
||||
p[c] = c;
|
||||
return p;
|
||||
}, {})
|
||||
}, {});
|
||||
}
|
||||
for(const k in output) {
|
||||
for (const k in output) {
|
||||
nodeDef.output.push(output[k]);
|
||||
nodeDef.output_name.push(k);
|
||||
nodeDef.output_is_list.push(false);
|
||||
@ -68,4 +76,31 @@ export function assertNotNullOrUndefined(x) {
|
||||
expect(x).not.toEqual(null);
|
||||
expect(x).not.toEqual(undefined);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param { ReturnType<Ez["graph"]>["ez"] } ez
|
||||
* @param { ReturnType<Ez["graph"]>["graph"] } graph
|
||||
*/
|
||||
export function createDefaultWorkflow(ez, graph) {
|
||||
graph.clear();
|
||||
const ckpt = ez.CheckpointLoaderSimple();
|
||||
|
||||
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
|
||||
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
|
||||
|
||||
const empty = ez.EmptyLatentImage();
|
||||
const sampler = ez.KSampler(
|
||||
ckpt.outputs.MODEL,
|
||||
pos.outputs.CONDITIONING,
|
||||
neg.outputs.CONDITIONING,
|
||||
empty.outputs.LATENT
|
||||
);
|
||||
|
||||
const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
|
||||
const save = ez.SaveImage(decode.outputs.IMAGE);
|
||||
graph.arrange();
|
||||
|
||||
return { ckpt, pos, neg, empty, sampler, decode, save };
|
||||
}
|
||||
|
@ -30,16 +30,20 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
|
||||
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
|
||||
}
|
||||
|
||||
const events = new EventTarget();
|
||||
const mockApi = {
|
||||
addEventListener: events.addEventListener.bind(events),
|
||||
removeEventListener: events.removeEventListener.bind(events),
|
||||
dispatchEvent: events.dispatchEvent.bind(events),
|
||||
getSystemStats: jest.fn(),
|
||||
getExtensions: jest.fn(() => mockExtensions),
|
||||
getNodeDefs: jest.fn(() => mockNodeDefs),
|
||||
init: jest.fn(),
|
||||
apiURL: jest.fn((x) => "../../web/" + x),
|
||||
};
|
||||
jest.mock("../../web/scripts/api", () => ({
|
||||
get api() {
|
||||
return {
|
||||
addEventListener: jest.fn(),
|
||||
getSystemStats: jest.fn(),
|
||||
getExtensions: jest.fn(() => mockExtensions),
|
||||
getNodeDefs: jest.fn(() => mockNodeDefs),
|
||||
init: jest.fn(),
|
||||
apiURL: jest.fn((x) => "../../web/" + x),
|
||||
};
|
||||
return mockApi;
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
@ -174,6 +174,213 @@ const colorPalettes = {
|
||||
"tr-odd-bg-color": "#073642",
|
||||
}
|
||||
},
|
||||
},
|
||||
"arc": {
|
||||
"id": "arc",
|
||||
"name": "Arc",
|
||||
"colors": {
|
||||
"node_slot": {
|
||||
"BOOLEAN": "",
|
||||
"CLIP": "#eacb8b",
|
||||
"CLIP_VISION": "#A8DADC",
|
||||
"CLIP_VISION_OUTPUT": "#ad7452",
|
||||
"CONDITIONING": "#cf876f",
|
||||
"CONTROL_NET": "#00d78d",
|
||||
"CONTROL_NET_WEIGHTS": "",
|
||||
"FLOAT": "",
|
||||
"GLIGEN": "",
|
||||
"IMAGE": "#80a1c0",
|
||||
"IMAGEUPLOAD": "",
|
||||
"INT": "",
|
||||
"LATENT": "#b38ead",
|
||||
"LATENT_KEYFRAME": "",
|
||||
"MASK": "#a3bd8d",
|
||||
"MODEL": "#8978a7",
|
||||
"SAMPLER": "",
|
||||
"SIGMAS": "",
|
||||
"STRING": "",
|
||||
"STYLE_MODEL": "#C2FFAE",
|
||||
"T2I_ADAPTER_WEIGHTS": "",
|
||||
"TAESD": "#DCC274",
|
||||
"TIMESTEP_KEYFRAME": "",
|
||||
"UPSCALE_MODEL": "",
|
||||
"VAE": "#be616b"
|
||||
},
|
||||
"litegraph_base": {
|
||||
"BACKGROUND_IMAGE": "",
|
||||
"CLEAR_BACKGROUND_COLOR": "#2b2f38",
|
||||
"NODE_TITLE_COLOR": "#b2b7bd",
|
||||
"NODE_SELECTED_TITLE_COLOR": "#FFF",
|
||||
"NODE_TEXT_SIZE": 14,
|
||||
"NODE_TEXT_COLOR": "#AAA",
|
||||
"NODE_SUBTEXT_SIZE": 12,
|
||||
"NODE_DEFAULT_COLOR": "#2b2f38",
|
||||
"NODE_DEFAULT_BGCOLOR": "#242730",
|
||||
"NODE_DEFAULT_BOXCOLOR": "#6e7581",
|
||||
"NODE_DEFAULT_SHAPE": "box",
|
||||
"NODE_BOX_OUTLINE_COLOR": "#FFF",
|
||||
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
|
||||
"DEFAULT_GROUP_FONT": 22,
|
||||
"WIDGET_BGCOLOR": "#2b2f38",
|
||||
"WIDGET_OUTLINE_COLOR": "#6e7581",
|
||||
"WIDGET_TEXT_COLOR": "#DDD",
|
||||
"WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd",
|
||||
"LINK_COLOR": "#9A9",
|
||||
"EVENT_LINK_COLOR": "#A86",
|
||||
"CONNECTING_LINK_COLOR": "#AFA"
|
||||
},
|
||||
"comfy_base": {
|
||||
"fg-color": "#fff",
|
||||
"bg-color": "#2b2f38",
|
||||
"comfy-menu-bg": "#242730",
|
||||
"comfy-input-bg": "#2b2f38",
|
||||
"input-text": "#ddd",
|
||||
"descrip-text": "#b2b7bd",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#6e7581",
|
||||
"tr-even-bg-color": "#2b2f38",
|
||||
"tr-odd-bg-color": "#242730"
|
||||
}
|
||||
},
|
||||
},
|
||||
"nord": {
|
||||
"id": "nord",
|
||||
"name": "Nord",
|
||||
"colors": {
|
||||
"node_slot": {
|
||||
"BOOLEAN": "",
|
||||
"CLIP": "#eacb8b",
|
||||
"CLIP_VISION": "#A8DADC",
|
||||
"CLIP_VISION_OUTPUT": "#ad7452",
|
||||
"CONDITIONING": "#cf876f",
|
||||
"CONTROL_NET": "#00d78d",
|
||||
"CONTROL_NET_WEIGHTS": "",
|
||||
"FLOAT": "",
|
||||
"GLIGEN": "",
|
||||
"IMAGE": "#80a1c0",
|
||||
"IMAGEUPLOAD": "",
|
||||
"INT": "",
|
||||
"LATENT": "#b38ead",
|
||||
"LATENT_KEYFRAME": "",
|
||||
"MASK": "#a3bd8d",
|
||||
"MODEL": "#8978a7",
|
||||
"SAMPLER": "",
|
||||
"SIGMAS": "",
|
||||
"STRING": "",
|
||||
"STYLE_MODEL": "#C2FFAE",
|
||||
"T2I_ADAPTER_WEIGHTS": "",
|
||||
"TAESD": "#DCC274",
|
||||
"TIMESTEP_KEYFRAME": "",
|
||||
"UPSCALE_MODEL": "",
|
||||
"VAE": "#be616b"
|
||||
},
|
||||
"litegraph_base": {
|
||||
"BACKGROUND_IMAGE": "",
|
||||
"CLEAR_BACKGROUND_COLOR": "#212732",
|
||||
"NODE_TITLE_COLOR": "#999",
|
||||
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
|
||||
"NODE_TEXT_SIZE": 14,
|
||||
"NODE_TEXT_COLOR": "#bcc2c8",
|
||||
"NODE_SUBTEXT_SIZE": 12,
|
||||
"NODE_DEFAULT_COLOR": "#2e3440",
|
||||
"NODE_DEFAULT_BGCOLOR": "#161b22",
|
||||
"NODE_DEFAULT_BOXCOLOR": "#545d70",
|
||||
"NODE_DEFAULT_SHAPE": "box",
|
||||
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
|
||||
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
|
||||
"DEFAULT_GROUP_FONT": 24,
|
||||
"WIDGET_BGCOLOR": "#2e3440",
|
||||
"WIDGET_OUTLINE_COLOR": "#545d70",
|
||||
"WIDGET_TEXT_COLOR": "#bcc2c8",
|
||||
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
|
||||
"LINK_COLOR": "#9A9",
|
||||
"EVENT_LINK_COLOR": "#A86",
|
||||
"CONNECTING_LINK_COLOR": "#AFA"
|
||||
},
|
||||
"comfy_base": {
|
||||
"fg-color": "#e5eaf0",
|
||||
"bg-color": "#2e3440",
|
||||
"comfy-menu-bg": "#161b22",
|
||||
"comfy-input-bg": "#2e3440",
|
||||
"input-text": "#bcc2c8",
|
||||
"descrip-text": "#999",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#545d70",
|
||||
"tr-even-bg-color": "#2e3440",
|
||||
"tr-odd-bg-color": "#161b22"
|
||||
}
|
||||
},
|
||||
},
|
||||
"github": {
|
||||
"id": "github",
|
||||
"name": "Github",
|
||||
"colors": {
|
||||
"node_slot": {
|
||||
"BOOLEAN": "",
|
||||
"CLIP": "#eacb8b",
|
||||
"CLIP_VISION": "#A8DADC",
|
||||
"CLIP_VISION_OUTPUT": "#ad7452",
|
||||
"CONDITIONING": "#cf876f",
|
||||
"CONTROL_NET": "#00d78d",
|
||||
"CONTROL_NET_WEIGHTS": "",
|
||||
"FLOAT": "",
|
||||
"GLIGEN": "",
|
||||
"IMAGE": "#80a1c0",
|
||||
"IMAGEUPLOAD": "",
|
||||
"INT": "",
|
||||
"LATENT": "#b38ead",
|
||||
"LATENT_KEYFRAME": "",
|
||||
"MASK": "#a3bd8d",
|
||||
"MODEL": "#8978a7",
|
||||
"SAMPLER": "",
|
||||
"SIGMAS": "",
|
||||
"STRING": "",
|
||||
"STYLE_MODEL": "#C2FFAE",
|
||||
"T2I_ADAPTER_WEIGHTS": "",
|
||||
"TAESD": "#DCC274",
|
||||
"TIMESTEP_KEYFRAME": "",
|
||||
"UPSCALE_MODEL": "",
|
||||
"VAE": "#be616b"
|
||||
},
|
||||
"litegraph_base": {
|
||||
"BACKGROUND_IMAGE": "",
|
||||
"CLEAR_BACKGROUND_COLOR": "#040506",
|
||||
"NODE_TITLE_COLOR": "#999",
|
||||
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
|
||||
"NODE_TEXT_SIZE": 14,
|
||||
"NODE_TEXT_COLOR": "#bcc2c8",
|
||||
"NODE_SUBTEXT_SIZE": 12,
|
||||
"NODE_DEFAULT_COLOR": "#161b22",
|
||||
"NODE_DEFAULT_BGCOLOR": "#13171d",
|
||||
"NODE_DEFAULT_BOXCOLOR": "#30363d",
|
||||
"NODE_DEFAULT_SHAPE": "box",
|
||||
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
|
||||
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
|
||||
"DEFAULT_GROUP_FONT": 24,
|
||||
"WIDGET_BGCOLOR": "#161b22",
|
||||
"WIDGET_OUTLINE_COLOR": "#30363d",
|
||||
"WIDGET_TEXT_COLOR": "#bcc2c8",
|
||||
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
|
||||
"LINK_COLOR": "#9A9",
|
||||
"EVENT_LINK_COLOR": "#A86",
|
||||
"CONNECTING_LINK_COLOR": "#AFA"
|
||||
},
|
||||
"comfy_base": {
|
||||
"fg-color": "#e5eaf0",
|
||||
"bg-color": "#161b22",
|
||||
"comfy-menu-bg": "#13171d",
|
||||
"comfy-input-bg": "#161b22",
|
||||
"input-text": "#bcc2c8",
|
||||
"descrip-text": "#999",
|
||||
"drag-text": "#ccc",
|
||||
"error-text": "#ff4444",
|
||||
"border-color": "#30363d",
|
||||
"tr-even-bg-color": "#161b22",
|
||||
"tr-odd-bg-color": "#13171d"
|
||||
}
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
|
1053
web/extensions/core/groupNode.js
Normal file
1053
web/extensions/core/groupNode.js
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,6 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { ComfyDialog, $el } from "../../scripts/ui.js";
|
||||
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
|
||||
|
||||
// Adds the ability to save and add multiple nodes as a template
|
||||
// To save:
|
||||
@ -34,7 +35,7 @@ class ManageTemplates extends ComfyDialog {
|
||||
type: "file",
|
||||
accept: ".json",
|
||||
multiple: true,
|
||||
style: {display: "none"},
|
||||
style: { display: "none" },
|
||||
parent: document.body,
|
||||
onchange: () => this.importAll(),
|
||||
});
|
||||
@ -109,13 +110,13 @@ class ManageTemplates extends ComfyDialog {
|
||||
return;
|
||||
}
|
||||
|
||||
const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], {type: "application/json"});
|
||||
const json = JSON.stringify({ templates: this.templates }, null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], { type: "application/json" });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: "node_templates.json",
|
||||
style: {display: "none"},
|
||||
style: { display: "none" },
|
||||
parent: document.body,
|
||||
});
|
||||
a.click();
|
||||
@ -291,11 +292,11 @@ app.registerExtension({
|
||||
setup() {
|
||||
const manage = new ManageTemplates();
|
||||
|
||||
const clipboardAction = (cb) => {
|
||||
const clipboardAction = async (cb) => {
|
||||
// We use the clipboard functions but dont want to overwrite the current user clipboard
|
||||
// Restore it after we've run our callback
|
||||
const old = localStorage.getItem("litegrapheditor_clipboard");
|
||||
cb();
|
||||
await cb();
|
||||
localStorage.setItem("litegrapheditor_clipboard", old);
|
||||
};
|
||||
|
||||
@ -309,13 +310,31 @@ app.registerExtension({
|
||||
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
|
||||
callback: () => {
|
||||
const name = prompt("Enter name");
|
||||
if (!name || !name.trim()) return;
|
||||
if (!name?.trim()) return;
|
||||
|
||||
clipboardAction(() => {
|
||||
app.canvas.copyToClipboard();
|
||||
let data = localStorage.getItem("litegrapheditor_clipboard");
|
||||
data = JSON.parse(data);
|
||||
const nodeIds = Object.keys(app.canvas.selected_nodes);
|
||||
for (let i = 0; i < nodeIds.length; i++) {
|
||||
const node = app.graph.getNodeById(nodeIds[i]);
|
||||
const nodeData = node?.constructor.nodeData;
|
||||
|
||||
let groupData = GroupNodeHandler.getGroupData(node);
|
||||
if (groupData) {
|
||||
groupData = groupData.nodeData;
|
||||
if (!data.groupNodes) {
|
||||
data.groupNodes = {};
|
||||
}
|
||||
data.groupNodes[nodeData.name] = groupData;
|
||||
data.nodes[i].type = nodeData.name;
|
||||
}
|
||||
}
|
||||
|
||||
manage.templates.push({
|
||||
name,
|
||||
data: localStorage.getItem("litegrapheditor_clipboard"),
|
||||
data: JSON.stringify(data),
|
||||
});
|
||||
manage.store();
|
||||
});
|
||||
@ -323,15 +342,19 @@ app.registerExtension({
|
||||
});
|
||||
|
||||
// Map each template to a menu item
|
||||
const subItems = manage.templates.map((t) => ({
|
||||
content: t.name,
|
||||
callback: () => {
|
||||
clipboardAction(() => {
|
||||
localStorage.setItem("litegrapheditor_clipboard", t.data);
|
||||
app.canvas.pasteFromClipboard();
|
||||
});
|
||||
},
|
||||
}));
|
||||
const subItems = manage.templates.map((t) => {
|
||||
return {
|
||||
content: t.name,
|
||||
callback: () => {
|
||||
clipboardAction(async () => {
|
||||
const data = JSON.parse(t.data);
|
||||
await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {});
|
||||
localStorage.setItem("litegrapheditor_clipboard", t.data);
|
||||
app.canvas.pasteFromClipboard();
|
||||
});
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
subItems.push(null, {
|
||||
content: "Manage",
|
||||
|
150
web/extensions/core/undoRedo.js
Normal file
150
web/extensions/core/undoRedo.js
Normal file
@ -0,0 +1,150 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
|
||||
const MAX_HISTORY = 50;
|
||||
|
||||
let undo = [];
|
||||
let redo = [];
|
||||
let activeState = null;
|
||||
let isOurLoad = false;
|
||||
function checkState() {
|
||||
const currentState = app.graph.serialize();
|
||||
if (!graphEqual(activeState, currentState)) {
|
||||
undo.push(activeState);
|
||||
if (undo.length > MAX_HISTORY) {
|
||||
undo.shift();
|
||||
}
|
||||
activeState = clone(currentState);
|
||||
redo.length = 0;
|
||||
}
|
||||
}
|
||||
|
||||
const loadGraphData = app.loadGraphData;
|
||||
app.loadGraphData = async function () {
|
||||
const v = await loadGraphData.apply(this, arguments);
|
||||
if (isOurLoad) {
|
||||
isOurLoad = false;
|
||||
} else {
|
||||
checkState();
|
||||
}
|
||||
return v;
|
||||
};
|
||||
|
||||
function clone(obj) {
|
||||
try {
|
||||
if (typeof structuredClone !== "undefined") {
|
||||
return structuredClone(obj);
|
||||
}
|
||||
} catch (error) {
|
||||
// structuredClone is stricter than using JSON.parse/stringify so fallback to that
|
||||
}
|
||||
|
||||
return JSON.parse(JSON.stringify(obj));
|
||||
}
|
||||
|
||||
function graphEqual(a, b, root = true) {
|
||||
if (a === b) return true;
|
||||
|
||||
if (typeof a == "object" && a && typeof b == "object" && b) {
|
||||
const keys = Object.getOwnPropertyNames(a);
|
||||
|
||||
if (keys.length != Object.getOwnPropertyNames(b).length) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const key of keys) {
|
||||
let av = a[key];
|
||||
let bv = b[key];
|
||||
if (root && key === "nodes") {
|
||||
// Nodes need to be sorted as the order changes when selecting nodes
|
||||
av = [...av].sort((a, b) => a.id - b.id);
|
||||
bv = [...bv].sort((a, b) => a.id - b.id);
|
||||
}
|
||||
if (!graphEqual(av, bv, false)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
const undoRedo = async (e) => {
|
||||
if (e.ctrlKey || e.metaKey) {
|
||||
if (e.key === "y") {
|
||||
const prevState = redo.pop();
|
||||
if (prevState) {
|
||||
undo.push(activeState);
|
||||
isOurLoad = true;
|
||||
await app.loadGraphData(prevState);
|
||||
activeState = prevState;
|
||||
}
|
||||
return true;
|
||||
} else if (e.key === "z") {
|
||||
const prevState = undo.pop();
|
||||
if (prevState) {
|
||||
redo.push(activeState);
|
||||
isOurLoad = true;
|
||||
await app.loadGraphData(prevState);
|
||||
activeState = prevState;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const bindInput = (activeEl) => {
|
||||
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
|
||||
for (const evt of ["change", "input", "blur"]) {
|
||||
if (`on${evt}` in activeEl) {
|
||||
const listener = () => {
|
||||
checkState();
|
||||
activeEl.removeEventListener(evt, listener);
|
||||
};
|
||||
activeEl.addEventListener(evt, listener);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener(
|
||||
"keydown",
|
||||
(e) => {
|
||||
requestAnimationFrame(async () => {
|
||||
const activeEl = document.activeElement;
|
||||
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
|
||||
// Ignore events on inputs, they have their native history
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if this is a ctrl+z ctrl+y
|
||||
if (await undoRedo(e)) return;
|
||||
|
||||
// If our active element is some type of input then handle changes after they're done
|
||||
if (bindInput(activeEl)) return;
|
||||
checkState();
|
||||
});
|
||||
},
|
||||
true
|
||||
);
|
||||
|
||||
// Handle clicking DOM elements (e.g. widgets)
|
||||
window.addEventListener("mouseup", () => {
|
||||
checkState();
|
||||
});
|
||||
|
||||
// Handle litegraph clicks
|
||||
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
|
||||
LGraphCanvas.prototype.processMouseUp = function (e) {
|
||||
const v = processMouseUp.apply(this, arguments);
|
||||
checkState();
|
||||
return v;
|
||||
};
|
||||
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
|
||||
LGraphCanvas.prototype.processMouseDown = function (e) {
|
||||
const v = processMouseDown.apply(this, arguments);
|
||||
checkState();
|
||||
return v;
|
||||
};
|
@ -1,4 +1,4 @@
|
||||
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
|
||||
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
|
||||
import { app } from "../../scripts/app.js";
|
||||
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
@ -121,6 +121,110 @@ function isValidCombo(combo, obj) {
|
||||
return true;
|
||||
}
|
||||
|
||||
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
|
||||
if (!config1) {
|
||||
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
|
||||
}
|
||||
|
||||
if (config1[0] instanceof Array) {
|
||||
if (!isValidCombo(config1[0], config2[0])) return false;
|
||||
} else if (config1[0] !== config2[0]) {
|
||||
// Types dont match
|
||||
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
|
||||
|
||||
let customConfig;
|
||||
const getCustomConfig = () => {
|
||||
if (!customConfig) {
|
||||
if (typeof structuredClone === "undefined") {
|
||||
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
|
||||
} else {
|
||||
customConfig = structuredClone(config1[1] ?? {});
|
||||
}
|
||||
}
|
||||
return customConfig;
|
||||
};
|
||||
|
||||
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
|
||||
for (const k of keys.values()) {
|
||||
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
|
||||
let v1 = config1[1][k];
|
||||
let v2 = config2[1]?.[k];
|
||||
|
||||
if (v1 === v2 || (!v1 && !v2)) continue;
|
||||
|
||||
if (isNumber) {
|
||||
if (k === "min") {
|
||||
const theirMax = config2[1]?.["max"];
|
||||
if (theirMax != null && v1 > theirMax) {
|
||||
console.log("connection rejected: min > max", v1, theirMax);
|
||||
return false;
|
||||
}
|
||||
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
|
||||
continue;
|
||||
} else if (k === "max") {
|
||||
const theirMin = config2[1]?.["min"];
|
||||
if (theirMin != null && v1 < theirMin) {
|
||||
console.log("connection rejected: max < min", v1, theirMin);
|
||||
return false;
|
||||
}
|
||||
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
|
||||
continue;
|
||||
} else if (k === "step") {
|
||||
let step;
|
||||
if (v1 == null) {
|
||||
// No current step
|
||||
step = v2;
|
||||
} else if (v2 == null) {
|
||||
// No new step
|
||||
step = v1;
|
||||
} else {
|
||||
if (v1 < v2) {
|
||||
// Ensure v1 is larger for the mod
|
||||
const a = v2;
|
||||
v2 = v1;
|
||||
v1 = a;
|
||||
}
|
||||
if (v1 % v2) {
|
||||
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
|
||||
return false;
|
||||
}
|
||||
|
||||
step = v1;
|
||||
}
|
||||
|
||||
getCustomConfig()[k] = step;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (customConfig || forceUpdate) {
|
||||
if (customConfig) {
|
||||
output.widget[CONFIG] = [config1[0], customConfig];
|
||||
}
|
||||
|
||||
const widget = recreateWidget?.call(this);
|
||||
// When deleting a node this can be null
|
||||
if (widget) {
|
||||
const min = widget.options.min;
|
||||
const max = widget.options.max;
|
||||
if (min != null && widget.value < min) widget.value = min;
|
||||
if (max != null && widget.value > max) widget.value = max;
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
|
||||
return { customConfig };
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.WidgetInputs",
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
@ -308,7 +412,7 @@ app.registerExtension({
|
||||
this.isVirtualNode = true;
|
||||
}
|
||||
|
||||
applyToGraph() {
|
||||
applyToGraph(extraLinks = []) {
|
||||
if (!this.outputs[0].links?.length) return;
|
||||
|
||||
function get_links(node) {
|
||||
@ -325,10 +429,9 @@ app.registerExtension({
|
||||
return links;
|
||||
}
|
||||
|
||||
let links = get_links(this);
|
||||
let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
|
||||
// For each output link copy our value over the original widget value
|
||||
for (const l of links) {
|
||||
const linkInfo = app.graph.links[l];
|
||||
for (const linkInfo of links) {
|
||||
const node = this.graph.getNodeById(linkInfo.target_id);
|
||||
const input = node.inputs[linkInfo.target_slot];
|
||||
const widgetName = input.widget.name;
|
||||
@ -405,7 +508,12 @@ app.registerExtension({
|
||||
}
|
||||
|
||||
if (this.outputs[slot].links?.length) {
|
||||
return this.#isValidConnection(input);
|
||||
const valid = this.#isValidConnection(input);
|
||||
if (valid) {
|
||||
// On connect of additional outputs, copy our value to their widget
|
||||
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
|
||||
}
|
||||
return valid;
|
||||
}
|
||||
}
|
||||
|
||||
@ -462,12 +570,16 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
if (widget.type === "number" || widget.type === "combo") {
|
||||
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
|
||||
let control_value = this.widgets_values?.[1];
|
||||
if (!control_value) {
|
||||
control_value = "fixed";
|
||||
}
|
||||
addValueControlWidget(this, widget, control_value);
|
||||
addValueControlWidgets(this, widget, control_value, undefined, inputData);
|
||||
let filter = this.widgets_values?.[2];
|
||||
if(filter && this.widgets.length === 3) {
|
||||
this.widgets[2].value = filter;
|
||||
}
|
||||
}
|
||||
|
||||
// When our value changes, update other widgets to reflect our changes
|
||||
@ -503,6 +615,7 @@ app.registerExtension({
|
||||
this.#removeWidgets();
|
||||
this.#onFirstConnection(true);
|
||||
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
|
||||
return this.widgets[0];
|
||||
}
|
||||
|
||||
#mergeWidgetConfig() {
|
||||
@ -543,108 +656,8 @@ app.registerExtension({
|
||||
#isValidConnection(input, forceUpdate) {
|
||||
// Only allow connections where the configs match
|
||||
const output = this.outputs[0];
|
||||
const config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
|
||||
const config2 = input.widget[GET_CONFIG]();
|
||||
|
||||
if (config1[0] instanceof Array) {
|
||||
if (!isValidCombo(config1[0], config2[0])) return false;
|
||||
} else if (config1[0] !== config2[0]) {
|
||||
// Types dont match
|
||||
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
|
||||
|
||||
let customConfig;
|
||||
const getCustomConfig = () => {
|
||||
if (!customConfig) {
|
||||
if (typeof structuredClone === "undefined") {
|
||||
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
|
||||
} else {
|
||||
customConfig = structuredClone(config1[1] ?? {});
|
||||
}
|
||||
}
|
||||
return customConfig;
|
||||
};
|
||||
|
||||
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
|
||||
for (const k of keys.values()) {
|
||||
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
|
||||
let v1 = config1[1][k];
|
||||
let v2 = config2[1][k];
|
||||
|
||||
if (v1 === v2 || (!v1 && !v2)) continue;
|
||||
|
||||
if (isNumber) {
|
||||
if (k === "min") {
|
||||
const theirMax = config2[1]["max"];
|
||||
if (theirMax != null && v1 > theirMax) {
|
||||
console.log("connection rejected: min > max", v1, theirMax);
|
||||
return false;
|
||||
}
|
||||
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
|
||||
continue;
|
||||
} else if (k === "max") {
|
||||
const theirMin = config2[1]["min"];
|
||||
if (theirMin != null && v1 < theirMin) {
|
||||
console.log("connection rejected: max < min", v1, theirMin);
|
||||
return false;
|
||||
}
|
||||
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
|
||||
continue;
|
||||
} else if (k === "step") {
|
||||
let step;
|
||||
if (v1 == null) {
|
||||
// No current step
|
||||
step = v2;
|
||||
} else if (v2 == null) {
|
||||
// No new step
|
||||
step = v1;
|
||||
} else {
|
||||
if (v1 < v2) {
|
||||
// Ensure v1 is larger for the mod
|
||||
const a = v2;
|
||||
v2 = v1;
|
||||
v1 = a;
|
||||
}
|
||||
if (v1 % v2) {
|
||||
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
|
||||
return false;
|
||||
}
|
||||
|
||||
step = v1;
|
||||
}
|
||||
|
||||
getCustomConfig()[k] = step;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (customConfig || forceUpdate) {
|
||||
if (customConfig) {
|
||||
output.widget[CONFIG] = [config1[0], customConfig];
|
||||
}
|
||||
|
||||
this.#recreateWidget();
|
||||
|
||||
const widget = this.widgets[0];
|
||||
// When deleting a node this can be null
|
||||
if (widget) {
|
||||
const min = widget.options.min;
|
||||
const max = widget.options.max;
|
||||
if (min != null && widget.value < min) widget.value = min;
|
||||
if (max != null && widget.value > max) widget.value = max;
|
||||
widget.callback(widget.value);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
|
||||
}
|
||||
|
||||
#removeWidgets() {
|
||||
|
@ -2533,7 +2533,7 @@
|
||||
var w = this.widgets[i];
|
||||
if(!w)
|
||||
continue;
|
||||
if(w.options && w.options.property && this.properties[ w.options.property ])
|
||||
if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined))
|
||||
w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) );
|
||||
}
|
||||
if (info.widgets_values) {
|
||||
@ -5714,10 +5714,10 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
* @method enableWebGL
|
||||
**/
|
||||
LGraphCanvas.prototype.enableWebGL = function() {
|
||||
if (typeof GL === undefined) {
|
||||
if (typeof GL === "undefined") {
|
||||
throw "litegl.js must be included to use a WebGL canvas";
|
||||
}
|
||||
if (typeof enableWebGLCanvas === undefined) {
|
||||
if (typeof enableWebGLCanvas === "undefined") {
|
||||
throw "webglCanvas.js must be included to use this feature";
|
||||
}
|
||||
|
||||
@ -7110,15 +7110,16 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
}
|
||||
};
|
||||
|
||||
LGraphCanvas.prototype.copyToClipboard = function() {
|
||||
LGraphCanvas.prototype.copyToClipboard = function(nodes) {
|
||||
var clipboard_info = {
|
||||
nodes: [],
|
||||
links: []
|
||||
};
|
||||
var index = 0;
|
||||
var selected_nodes_array = [];
|
||||
for (var i in this.selected_nodes) {
|
||||
var node = this.selected_nodes[i];
|
||||
if (!nodes) nodes = this.selected_nodes;
|
||||
for (var i in nodes) {
|
||||
var node = nodes[i];
|
||||
if (node.clonable === false)
|
||||
continue;
|
||||
node._relative_id = index;
|
||||
@ -11702,7 +11703,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
default:
|
||||
iS = 0; // try with first if no name set
|
||||
}
|
||||
if (typeof options.node_from.outputs[iS] !== undefined){
|
||||
if (typeof options.node_from.outputs[iS] !== "undefined"){
|
||||
if (iS!==false && iS>-1){
|
||||
options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type );
|
||||
}
|
||||
@ -11730,7 +11731,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
default:
|
||||
iS = 0; // try with first if no name set
|
||||
}
|
||||
if (typeof options.node_to.inputs[iS] !== undefined){
|
||||
if (typeof options.node_to.inputs[iS] !== "undefined"){
|
||||
if (iS!==false && iS>-1){
|
||||
// try connection
|
||||
options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type);
|
||||
|
@ -254,9 +254,9 @@ class ComfyApi extends EventTarget {
|
||||
* Gets the prompt execution history
|
||||
* @returns Prompt history including node outputs
|
||||
*/
|
||||
async getHistory() {
|
||||
async getHistory(max_items=200) {
|
||||
try {
|
||||
const res = await this.fetchApi("/history");
|
||||
const res = await this.fetchApi(`/history?max_items=${max_items}`);
|
||||
return { History: Object.values(await res.json()) };
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
|
@ -4,7 +4,10 @@ import { ComfyUI, $el } from "./ui.js";
|
||||
import { api } from "./api.js";
|
||||
import { defaultGraph } from "./defaultGraph.js";
|
||||
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
|
||||
import { addDomClippingSetting } from "./domWidget.js";
|
||||
import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js"
|
||||
|
||||
export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"
|
||||
|
||||
function sanitizeNodeName(string) {
|
||||
let entityMap = {
|
||||
@ -409,7 +412,9 @@ export class ComfyApp {
|
||||
return shiftY;
|
||||
}
|
||||
|
||||
node.prototype.setSizeForImage = function () {
|
||||
node.prototype.setSizeForImage = function (force) {
|
||||
if(!force && this.animatedImages) return;
|
||||
|
||||
if (this.inputHeight) {
|
||||
this.setSize(this.size);
|
||||
return;
|
||||
@ -426,13 +431,20 @@ export class ComfyApp {
|
||||
let imagesChanged = false
|
||||
|
||||
const output = app.nodeOutputs[this.id + ""];
|
||||
if (output && output.images) {
|
||||
if (output?.images) {
|
||||
this.animatedImages = output?.animated?.find(Boolean);
|
||||
if (this.images !== output.images) {
|
||||
this.images = output.images;
|
||||
imagesChanged = true;
|
||||
imgURLs = imgURLs.concat(output.images.map(params => {
|
||||
return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam() + app.getRandParam());
|
||||
}))
|
||||
imgURLs = imgURLs.concat(
|
||||
output.images.map((params) => {
|
||||
return api.apiURL(
|
||||
"/view?" +
|
||||
new URLSearchParams(params).toString() +
|
||||
(this.animatedImages ? "" : app.getPreviewFormatParam()) + app.getRandParam()
|
||||
);
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -511,7 +523,35 @@ export class ComfyApp {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this.imgs && this.imgs.length) {
|
||||
if (this.imgs?.length) {
|
||||
const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET);
|
||||
|
||||
if(this.animatedImages) {
|
||||
// Instead of using the canvas we'll use a IMG
|
||||
if(widgetIdx > -1) {
|
||||
// Replace content
|
||||
const widget = this.widgets[widgetIdx];
|
||||
widget.options.host.updateImages(this.imgs);
|
||||
} else {
|
||||
const host = createImageHost(this);
|
||||
this.setSizeForImage(true);
|
||||
const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, {
|
||||
host,
|
||||
getHeight: host.getHeight,
|
||||
onDraw: host.onDraw,
|
||||
hideOnZoom: false
|
||||
});
|
||||
widget.serializeValue = () => undefined;
|
||||
widget.options.host.updateImages(this.imgs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (widgetIdx > -1) {
|
||||
this.widgets[widgetIdx].onRemove?.();
|
||||
this.widgets.splice(widgetIdx, 1);
|
||||
}
|
||||
|
||||
const canvas = app.graph.list_of_graphcanvas[0];
|
||||
const mouse = canvas.graph_mouse;
|
||||
if (!canvas.pointer_is_down && this.pointerDown) {
|
||||
@ -551,31 +591,7 @@ export class ComfyApp {
|
||||
}
|
||||
else {
|
||||
cell_padding = 0;
|
||||
let best = 0;
|
||||
let w = this.imgs[0].naturalWidth;
|
||||
let h = this.imgs[0].naturalHeight;
|
||||
|
||||
// compact style
|
||||
for (let c = 1; c <= numImages; c++) {
|
||||
const rows = Math.ceil(numImages / c);
|
||||
const cW = dw / c;
|
||||
const cH = dh / rows;
|
||||
const scaleX = cW / w;
|
||||
const scaleY = cH / h;
|
||||
|
||||
const scale = Math.min(scaleX, scaleY, 1);
|
||||
const imageW = w * scale;
|
||||
const imageH = h * scale;
|
||||
const area = imageW * imageH * numImages;
|
||||
|
||||
if (area > best) {
|
||||
best = area;
|
||||
cellWidth = imageW;
|
||||
cellHeight = imageH;
|
||||
cols = c;
|
||||
shiftX = c * ((cW - imageW) / 2);
|
||||
}
|
||||
}
|
||||
({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh));
|
||||
}
|
||||
|
||||
let anyHovered = false;
|
||||
@ -767,7 +783,7 @@ export class ComfyApp {
|
||||
* Adds a handler on paste that extracts and loads images or workflows from pasted JSON data
|
||||
*/
|
||||
#addPasteHandler() {
|
||||
document.addEventListener("paste", (e) => {
|
||||
document.addEventListener("paste", async (e) => {
|
||||
// ctrl+shift+v is used to paste nodes with connections
|
||||
// this is handled by litegraph
|
||||
if(this.shiftDown) return;
|
||||
@ -815,7 +831,7 @@ export class ComfyApp {
|
||||
}
|
||||
|
||||
if (workflow && workflow.version && workflow.nodes && workflow.extra) {
|
||||
this.loadGraphData(workflow);
|
||||
await this.loadGraphData(workflow);
|
||||
}
|
||||
else {
|
||||
if (e.target.type === "text" || e.target.type === "textarea") {
|
||||
@ -1165,7 +1181,19 @@ export class ComfyApp {
|
||||
});
|
||||
|
||||
api.addEventListener("executed", ({ detail }) => {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
const output = this.nodeOutputs[detail.node];
|
||||
if (detail.merge && output) {
|
||||
for (const k in detail.output ?? {}) {
|
||||
const v = output[k];
|
||||
if (v instanceof Array) {
|
||||
output[k] = v.concat(detail.output[k]);
|
||||
} else {
|
||||
output[k] = detail.output[k];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
}
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
@ -1276,9 +1304,11 @@ export class ComfyApp {
|
||||
canvasEl.tabIndex = "1";
|
||||
document.body.prepend(canvasEl);
|
||||
|
||||
addDomClippingSetting();
|
||||
this.#addProcessMouseHandler();
|
||||
this.#addProcessKeyHandler();
|
||||
this.#addConfigureHandler();
|
||||
this.#addApiUpdateHandlers();
|
||||
|
||||
this.graph = new LGraph();
|
||||
|
||||
@ -1315,7 +1345,7 @@ export class ComfyApp {
|
||||
const json = localStorage.getItem("workflow");
|
||||
if (json) {
|
||||
const workflow = JSON.parse(json);
|
||||
this.loadGraphData(workflow);
|
||||
await this.loadGraphData(workflow);
|
||||
restored = true;
|
||||
}
|
||||
} catch (err) {
|
||||
@ -1324,7 +1354,7 @@ export class ComfyApp {
|
||||
|
||||
// We failed to restore a workflow so load the default
|
||||
if (!restored) {
|
||||
this.loadGraphData();
|
||||
await this.loadGraphData();
|
||||
}
|
||||
|
||||
// Save current workflow automatically
|
||||
@ -1332,7 +1362,6 @@ export class ComfyApp {
|
||||
|
||||
this.#addDrawNodeHandler();
|
||||
this.#addDrawGroupsHandler();
|
||||
this.#addApiUpdateHandlers();
|
||||
this.#addDropHandler();
|
||||
this.#addCopyHandler();
|
||||
this.#addPasteHandler();
|
||||
@ -1352,11 +1381,95 @@ export class ComfyApp {
|
||||
await this.#invokeExtensionsAsync("registerCustomNodes");
|
||||
}
|
||||
|
||||
getWidgetType(inputData, inputName) {
|
||||
const type = inputData[0];
|
||||
|
||||
if (Array.isArray(type)) {
|
||||
return "COMBO";
|
||||
} else if (`${type}:${inputName}` in this.widgets) {
|
||||
return `${type}:${inputName}`;
|
||||
} else if (type in this.widgets) {
|
||||
return type;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async registerNodeDef(nodeId, nodeData) {
|
||||
const self = this;
|
||||
const node = Object.assign(
|
||||
function ComfyNode() {
|
||||
var inputs = nodeData["input"]["required"];
|
||||
if (nodeData["input"]["optional"] != undefined) {
|
||||
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]);
|
||||
}
|
||||
const config = { minWidth: 1, minHeight: 1 };
|
||||
for (const inputName in inputs) {
|
||||
const inputData = inputs[inputName];
|
||||
const type = inputData[0];
|
||||
|
||||
let widgetCreated = true;
|
||||
const widgetType = self.getWidgetType(inputData, inputName);
|
||||
if(widgetType) {
|
||||
if(widgetType === "COMBO") {
|
||||
Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {});
|
||||
} else {
|
||||
Object.assign(config, self.widgets[widgetType](this, inputName, inputData, app) || {});
|
||||
}
|
||||
} else {
|
||||
// Node connection inputs
|
||||
this.addInput(inputName, type);
|
||||
widgetCreated = false;
|
||||
}
|
||||
|
||||
if(widgetCreated && inputData[1]?.forceInput && config?.widget) {
|
||||
if (!config.widget.options) config.widget.options = {};
|
||||
config.widget.options.forceInput = inputData[1].forceInput;
|
||||
}
|
||||
if(widgetCreated && inputData[1]?.defaultInput && config?.widget) {
|
||||
if (!config.widget.options) config.widget.options = {};
|
||||
config.widget.options.defaultInput = inputData[1].defaultInput;
|
||||
}
|
||||
}
|
||||
|
||||
for (const o in nodeData["output"]) {
|
||||
let output = nodeData["output"][o];
|
||||
if(output instanceof Array) output = "COMBO";
|
||||
const outputName = nodeData["output_name"][o] || output;
|
||||
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||
this.addOutput(outputName, output, { shape: outputShape });
|
||||
}
|
||||
|
||||
const s = this.computeSize();
|
||||
s[0] = Math.max(config.minWidth, s[0] * 1.5);
|
||||
s[1] = Math.max(config.minHeight, s[1]);
|
||||
this.size = s;
|
||||
this.serialize_widgets = true;
|
||||
|
||||
app.#invokeExtensionsAsync("nodeCreated", this);
|
||||
},
|
||||
{
|
||||
title: nodeData.display_name || nodeData.name,
|
||||
comfyClass: nodeData.name,
|
||||
nodeData
|
||||
}
|
||||
);
|
||||
node.prototype.comfyClass = nodeData.name;
|
||||
|
||||
this.#addNodeContextMenuHandler(node);
|
||||
this.#addDrawBackgroundHandler(node, app);
|
||||
this.#addNodeKeyHandler(node);
|
||||
|
||||
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
|
||||
LiteGraph.registerNodeType(nodeId, node);
|
||||
node.category = nodeData.category;
|
||||
}
|
||||
|
||||
async registerNodesFromDefs(defs) {
|
||||
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
|
||||
|
||||
// Generate list of known widgets
|
||||
const widgets = Object.assign(
|
||||
this.widgets = Object.assign(
|
||||
{},
|
||||
ComfyWidgets,
|
||||
...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean)
|
||||
@ -1364,75 +1477,7 @@ export class ComfyApp {
|
||||
|
||||
// Register a node for each definition
|
||||
for (const nodeId in defs) {
|
||||
const nodeData = defs[nodeId];
|
||||
const node = Object.assign(
|
||||
function ComfyNode() {
|
||||
var inputs = nodeData["input"]["required"];
|
||||
if (nodeData["input"]["optional"] != undefined){
|
||||
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
|
||||
}
|
||||
const config = { minWidth: 1, minHeight: 1 };
|
||||
for (const inputName in inputs) {
|
||||
const inputData = inputs[inputName];
|
||||
const type = inputData[0];
|
||||
|
||||
let widgetCreated = true;
|
||||
if (Array.isArray(type)) {
|
||||
// Enums
|
||||
Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {});
|
||||
} else if (`${type}:${inputName}` in widgets) {
|
||||
// Support custom widgets by Type:Name
|
||||
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
|
||||
} else if (type in widgets) {
|
||||
// Standard type widgets
|
||||
Object.assign(config, widgets[type](this, inputName, inputData, app) || {});
|
||||
} else {
|
||||
// Node connection inputs
|
||||
this.addInput(inputName, type);
|
||||
widgetCreated = false;
|
||||
}
|
||||
|
||||
if(widgetCreated && inputData[1]?.forceInput && config?.widget) {
|
||||
if (!config.widget.options) config.widget.options = {};
|
||||
config.widget.options.forceInput = inputData[1].forceInput;
|
||||
}
|
||||
if(widgetCreated && inputData[1]?.defaultInput && config?.widget) {
|
||||
if (!config.widget.options) config.widget.options = {};
|
||||
config.widget.options.defaultInput = inputData[1].defaultInput;
|
||||
}
|
||||
}
|
||||
|
||||
for (const o in nodeData["output"]) {
|
||||
let output = nodeData["output"][o];
|
||||
if(output instanceof Array) output = "COMBO";
|
||||
const outputName = nodeData["output_name"][o] || output;
|
||||
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
|
||||
this.addOutput(outputName, output, { shape: outputShape });
|
||||
}
|
||||
|
||||
const s = this.computeSize();
|
||||
s[0] = Math.max(config.minWidth, s[0] * 1.5);
|
||||
s[1] = Math.max(config.minHeight, s[1]);
|
||||
this.size = s;
|
||||
this.serialize_widgets = true;
|
||||
|
||||
app.#invokeExtensionsAsync("nodeCreated", this);
|
||||
},
|
||||
{
|
||||
title: nodeData.display_name || nodeData.name,
|
||||
comfyClass: nodeData.name,
|
||||
nodeData
|
||||
}
|
||||
);
|
||||
node.prototype.comfyClass = nodeData.name;
|
||||
|
||||
this.#addNodeContextMenuHandler(node);
|
||||
this.#addDrawBackgroundHandler(node, app);
|
||||
this.#addNodeKeyHandler(node);
|
||||
|
||||
await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData);
|
||||
LiteGraph.registerNodeType(nodeId, node);
|
||||
node.category = nodeData.category;
|
||||
this.registerNodeDef(nodeId, defs[nodeId]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1475,9 +1520,14 @@ export class ComfyApp {
|
||||
|
||||
showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
|
||||
this.ui.dialog.show(
|
||||
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
|
||||
(t) => `<li>${t}</li>`
|
||||
).join("")}</ul>${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}`
|
||||
$el("div", [
|
||||
$el("span", { textContent: "When loading the graph, the following node types were not found: " }),
|
||||
$el(
|
||||
"ul",
|
||||
Array.from(new Set(missingNodeTypes)).map((t) => $el("li", { textContent: t }))
|
||||
),
|
||||
...(hasAddedNodes ? [$el("span", { textContent: "Nodes that have failed to load will show as red on the graph." })] : []),
|
||||
])
|
||||
);
|
||||
this.logging.addEntry("Comfy.App", "warn", {
|
||||
MissingNodes: missingNodeTypes,
|
||||
@ -1488,31 +1538,35 @@ export class ComfyApp {
|
||||
* Populates the graph with the specified workflow data
|
||||
* @param {*} graphData A serialized graph object
|
||||
*/
|
||||
loadGraphData(graphData) {
|
||||
async loadGraphData(graphData) {
|
||||
this.clean();
|
||||
|
||||
let reset_invalid_values = false;
|
||||
if (!graphData) {
|
||||
if (typeof structuredClone === "undefined")
|
||||
{
|
||||
graphData = JSON.parse(JSON.stringify(defaultGraph));
|
||||
}else
|
||||
{
|
||||
graphData = structuredClone(defaultGraph);
|
||||
}
|
||||
graphData = defaultGraph;
|
||||
reset_invalid_values = true;
|
||||
}
|
||||
|
||||
if (typeof structuredClone === "undefined")
|
||||
{
|
||||
graphData = JSON.parse(JSON.stringify(graphData));
|
||||
}else
|
||||
{
|
||||
graphData = structuredClone(graphData);
|
||||
}
|
||||
|
||||
const missingNodeTypes = [];
|
||||
await this.#invokeExtensionsAsync("beforeConfigureGraph", graphData, missingNodeTypes);
|
||||
for (let n of graphData.nodes) {
|
||||
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
|
||||
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
|
||||
if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix
|
||||
if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix
|
||||
|
||||
// Find missing node types
|
||||
if (!(n.type in LiteGraph.registered_node_types)) {
|
||||
n.type = sanitizeNodeName(n.type);
|
||||
missingNodeTypes.push(n.type);
|
||||
n.type = sanitizeNodeName(n.type);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1604,6 +1658,7 @@ export class ComfyApp {
|
||||
if (missingNodeTypes.length) {
|
||||
this.showMissingNodesError(missingNodeTypes);
|
||||
}
|
||||
await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1611,92 +1666,98 @@ export class ComfyApp {
|
||||
* @returns The workflow and node links
|
||||
*/
|
||||
async graphToPrompt() {
|
||||
for (const node of this.graph.computeExecutionOrder(false)) {
|
||||
if (node.isVirtualNode) {
|
||||
// Don't serialize frontend only nodes but let them make changes
|
||||
if (node.applyToGraph) {
|
||||
node.applyToGraph();
|
||||
for (const outerNode of this.graph.computeExecutionOrder(false)) {
|
||||
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
|
||||
for (const node of innerNodes) {
|
||||
if (node.isVirtualNode) {
|
||||
// Don't serialize frontend only nodes but let them make changes
|
||||
if (node.applyToGraph) {
|
||||
node.applyToGraph();
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const workflow = this.graph.serialize();
|
||||
const output = {};
|
||||
// Process nodes in order of execution
|
||||
for (const node of this.graph.computeExecutionOrder(false)) {
|
||||
const n = workflow.nodes.find((n) => n.id === node.id);
|
||||
for (const outerNode of this.graph.computeExecutionOrder(false)) {
|
||||
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
|
||||
for (const node of innerNodes) {
|
||||
if (node.isVirtualNode) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node.isVirtualNode) {
|
||||
continue;
|
||||
}
|
||||
if (node.mode === 2 || node.mode === 4) {
|
||||
// Don't serialize muted nodes
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node.mode === 2 || node.mode === 4) {
|
||||
// Don't serialize muted nodes
|
||||
continue;
|
||||
}
|
||||
const inputs = {};
|
||||
const widgets = node.widgets;
|
||||
|
||||
const inputs = {};
|
||||
const widgets = node.widgets;
|
||||
|
||||
// Store all widget values
|
||||
if (widgets) {
|
||||
for (const i in widgets) {
|
||||
const widget = widgets[i];
|
||||
if (!widget.options || widget.options.serialize !== false) {
|
||||
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
|
||||
// Store all widget values
|
||||
if (widgets) {
|
||||
for (const i in widgets) {
|
||||
const widget = widgets[i];
|
||||
if (!widget.options || widget.options.serialize !== false) {
|
||||
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(node, i) : widget.value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store all node links
|
||||
for (let i in node.inputs) {
|
||||
let parent = node.getInputNode(i);
|
||||
if (parent) {
|
||||
let link = node.getInputLink(i);
|
||||
while (parent.mode === 4 || parent.isVirtualNode) {
|
||||
let found = false;
|
||||
if (parent.isVirtualNode) {
|
||||
link = parent.getInputLink(link.origin_slot);
|
||||
if (link) {
|
||||
parent = parent.getInputNode(link.target_slot);
|
||||
if (parent) {
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
} else if (link && parent.mode === 4) {
|
||||
let all_inputs = [link.origin_slot];
|
||||
if (parent.inputs) {
|
||||
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
|
||||
for (let parent_input in all_inputs) {
|
||||
parent_input = all_inputs[parent_input];
|
||||
if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
|
||||
link = parent.getInputLink(parent_input);
|
||||
if (link) {
|
||||
parent = parent.getInputNode(parent_input);
|
||||
}
|
||||
// Store all node links
|
||||
for (let i in node.inputs) {
|
||||
let parent = node.getInputNode(i);
|
||||
if (parent) {
|
||||
let link = node.getInputLink(i);
|
||||
while (parent.mode === 4 || parent.isVirtualNode) {
|
||||
let found = false;
|
||||
if (parent.isVirtualNode) {
|
||||
link = parent.getInputLink(link.origin_slot);
|
||||
if (link) {
|
||||
parent = parent.getInputNode(link.target_slot);
|
||||
if (parent) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (link && parent.mode === 4) {
|
||||
let all_inputs = [link.origin_slot];
|
||||
if (parent.inputs) {
|
||||
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
|
||||
for (let parent_input in all_inputs) {
|
||||
parent_input = all_inputs[parent_input];
|
||||
if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
|
||||
link = parent.getInputLink(parent_input);
|
||||
if (link) {
|
||||
parent = parent.getInputNode(parent_input);
|
||||
}
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
break;
|
||||
if (link) {
|
||||
if (parent?.updateLink) {
|
||||
link = parent.updateLink(link);
|
||||
}
|
||||
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
|
||||
}
|
||||
}
|
||||
|
||||
if (link) {
|
||||
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[String(node.id)] = {
|
||||
inputs,
|
||||
class_type: node.comfyClass,
|
||||
};
|
||||
output[String(node.id)] = {
|
||||
inputs,
|
||||
class_type: node.comfyClass,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Remove inputs connected to removed nodes
|
||||
@ -1816,7 +1877,7 @@ export class ComfyApp {
|
||||
const pngInfo = await getPngMetadata(file);
|
||||
if (pngInfo) {
|
||||
if (pngInfo.workflow) {
|
||||
this.loadGraphData(JSON.parse(pngInfo.workflow));
|
||||
await this.loadGraphData(JSON.parse(pngInfo.workflow));
|
||||
} else if (pngInfo.parameters) {
|
||||
importA1111(this.graph, pngInfo.parameters);
|
||||
}
|
||||
@ -1832,21 +1893,21 @@ export class ComfyApp {
|
||||
}
|
||||
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => {
|
||||
reader.onload = async () => {
|
||||
const jsonContent = JSON.parse(reader.result);
|
||||
if (jsonContent?.templates) {
|
||||
this.loadTemplateData(jsonContent);
|
||||
} else if(this.isApiJson(jsonContent)) {
|
||||
this.loadApiJson(jsonContent);
|
||||
} else {
|
||||
this.loadGraphData(jsonContent);
|
||||
await this.loadGraphData(jsonContent);
|
||||
}
|
||||
};
|
||||
reader.readAsText(file);
|
||||
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
|
||||
const info = await getLatentMetadata(file);
|
||||
if (info.workflow) {
|
||||
this.loadGraphData(JSON.parse(info.workflow));
|
||||
await this.loadGraphData(JSON.parse(info.workflow));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1867,7 +1928,7 @@ export class ComfyApp {
|
||||
for (const id of ids) {
|
||||
const data = apiData[id];
|
||||
const node = LiteGraph.createNode(data.class_type);
|
||||
node.id = id;
|
||||
node.id = isNaN(+id) ? id : +id;
|
||||
graph.add(node);
|
||||
}
|
||||
|
||||
|
322
web/scripts/domWidget.js
Normal file
322
web/scripts/domWidget.js
Normal file
@ -0,0 +1,322 @@
|
||||
import { app, ANIM_PREVIEW_WIDGET } from "./app.js";
|
||||
|
||||
const SIZE = Symbol();
|
||||
|
||||
function intersect(a, b) {
|
||||
const x = Math.max(a.x, b.x);
|
||||
const num1 = Math.min(a.x + a.width, b.x + b.width);
|
||||
const y = Math.max(a.y, b.y);
|
||||
const num2 = Math.min(a.y + a.height, b.y + b.height);
|
||||
if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y];
|
||||
else return null;
|
||||
}
|
||||
|
||||
function getClipPath(node, element, elRect) {
|
||||
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
|
||||
if (selectedNode && selectedNode !== node) {
|
||||
const MARGIN = 7;
|
||||
const scale = app.canvas.ds.scale;
|
||||
|
||||
const bounding = selectedNode.getBounding();
|
||||
const intersection = intersect(
|
||||
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
|
||||
{
|
||||
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
|
||||
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
|
||||
width: bounding[2] + MARGIN + MARGIN,
|
||||
height: bounding[3] + MARGIN + MARGIN,
|
||||
}
|
||||
);
|
||||
|
||||
if (!intersection) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const widgetRect = element.getBoundingClientRect();
|
||||
const clipX = intersection[0] - widgetRect.x / scale + "px";
|
||||
const clipY = intersection[1] - widgetRect.y / scale + "px";
|
||||
const clipWidth = intersection[2] + "px";
|
||||
const clipHeight = intersection[3] + "px";
|
||||
const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`;
|
||||
return path;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
function computeSize(size) {
|
||||
if (this.widgets?.[0]?.last_y == null) return;
|
||||
|
||||
let y = this.widgets[0].last_y;
|
||||
let freeSpace = size[1] - y;
|
||||
|
||||
let widgetHeight = 0;
|
||||
let dom = [];
|
||||
for (const w of this.widgets) {
|
||||
if (w.type === "converted-widget") {
|
||||
// Ignore
|
||||
delete w.computedHeight;
|
||||
} else if (w.computeSize) {
|
||||
widgetHeight += w.computeSize()[1] + 4;
|
||||
} else if (w.element) {
|
||||
// Extract DOM widget size info
|
||||
const styles = getComputedStyle(w.element);
|
||||
let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height"));
|
||||
let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height"));
|
||||
|
||||
let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height");
|
||||
if (prefHeight.endsWith?.("%")) {
|
||||
prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100);
|
||||
} else {
|
||||
prefHeight = parseInt(prefHeight);
|
||||
if (isNaN(minHeight)) {
|
||||
minHeight = prefHeight;
|
||||
}
|
||||
}
|
||||
if (isNaN(minHeight)) {
|
||||
minHeight = 50;
|
||||
}
|
||||
if (!isNaN(maxHeight)) {
|
||||
if (!isNaN(prefHeight)) {
|
||||
prefHeight = Math.min(prefHeight, maxHeight);
|
||||
} else {
|
||||
prefHeight = maxHeight;
|
||||
}
|
||||
}
|
||||
dom.push({
|
||||
minHeight,
|
||||
prefHeight,
|
||||
w,
|
||||
});
|
||||
} else {
|
||||
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
|
||||
freeSpace -= widgetHeight;
|
||||
|
||||
// Calculate sizes with all widgets at their min height
|
||||
const prefGrow = []; // Nodes that want to grow to their prefd size
|
||||
const canGrow = []; // Nodes that can grow to auto size
|
||||
let growBy = 0;
|
||||
for (const d of dom) {
|
||||
freeSpace -= d.minHeight;
|
||||
if (isNaN(d.prefHeight)) {
|
||||
canGrow.push(d);
|
||||
d.w.computedHeight = d.minHeight;
|
||||
} else {
|
||||
const diff = d.prefHeight - d.minHeight;
|
||||
if (diff > 0) {
|
||||
prefGrow.push(d);
|
||||
growBy += diff;
|
||||
d.diff = diff;
|
||||
} else {
|
||||
d.w.computedHeight = d.minHeight;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) {
|
||||
// Allocate space for image
|
||||
freeSpace -= 220;
|
||||
}
|
||||
|
||||
if (freeSpace < 0) {
|
||||
// Not enough space for all widgets so we need to grow
|
||||
size[1] -= freeSpace;
|
||||
this.graph.setDirtyCanvas(true);
|
||||
} else {
|
||||
// Share the space between each
|
||||
const growDiff = freeSpace - growBy;
|
||||
if (growDiff > 0) {
|
||||
// All pref sizes can be fulfilled
|
||||
freeSpace = growDiff;
|
||||
for (const d of prefGrow) {
|
||||
d.w.computedHeight = d.prefHeight;
|
||||
}
|
||||
} else {
|
||||
// We need to grow evenly
|
||||
const shared = -growDiff / prefGrow.length;
|
||||
for (const d of prefGrow) {
|
||||
d.w.computedHeight = d.prefHeight - shared;
|
||||
}
|
||||
freeSpace = 0;
|
||||
}
|
||||
|
||||
if (freeSpace > 0 && canGrow.length) {
|
||||
// Grow any that are auto height
|
||||
const shared = freeSpace / canGrow.length;
|
||||
for (const d of canGrow) {
|
||||
d.w.computedHeight += shared;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Position each of the widgets
|
||||
for (const w of this.widgets) {
|
||||
w.y = y;
|
||||
if (w.computedHeight) {
|
||||
y += w.computedHeight;
|
||||
} else if (w.computeSize) {
|
||||
y += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen
|
||||
const elementWidgets = new Set();
|
||||
const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
|
||||
LGraphCanvas.prototype.computeVisibleNodes = function () {
|
||||
const visibleNodes = computeVisibleNodes.apply(this, arguments);
|
||||
for (const node of app.graph._nodes) {
|
||||
if (elementWidgets.has(node)) {
|
||||
const hidden = visibleNodes.indexOf(node) === -1;
|
||||
for (const w of node.widgets) {
|
||||
if (w.element) {
|
||||
w.element.hidden = hidden;
|
||||
if (hidden) {
|
||||
w.options.onHide?.(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return visibleNodes;
|
||||
};
|
||||
|
||||
let enableDomClipping = true;
|
||||
|
||||
export function addDomClippingSetting() {
|
||||
app.ui.settings.addSetting({
|
||||
id: "Comfy.DOMClippingEnabled",
|
||||
name: "Enable DOM element clipping (enabling may reduce performance)",
|
||||
type: "boolean",
|
||||
defaultValue: enableDomClipping,
|
||||
onChange(value) {
|
||||
enableDomClipping = !!value;
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
|
||||
options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options };
|
||||
|
||||
if (!element.parentElement) {
|
||||
document.body.append(element);
|
||||
}
|
||||
|
||||
let mouseDownHandler;
|
||||
if (element.blur) {
|
||||
mouseDownHandler = (event) => {
|
||||
if (!element.contains(event.target)) {
|
||||
element.blur();
|
||||
}
|
||||
};
|
||||
document.addEventListener("mousedown", mouseDownHandler);
|
||||
}
|
||||
|
||||
const widget = {
|
||||
type,
|
||||
name,
|
||||
get value() {
|
||||
return options.getValue?.() ?? undefined;
|
||||
},
|
||||
set value(v) {
|
||||
options.setValue?.(v);
|
||||
widget.callback?.(widget.value);
|
||||
},
|
||||
draw: function (ctx, node, widgetWidth, y, widgetHeight) {
|
||||
if (widget.computedHeight == null) {
|
||||
computeSize.call(node, node.size);
|
||||
}
|
||||
|
||||
const hidden =
|
||||
node.flags?.collapsed ||
|
||||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
|
||||
widget.computedHeight <= 0 ||
|
||||
widget.type === "converted-widget";
|
||||
element.hidden = hidden;
|
||||
element.style.display = hidden ? "none" : null;
|
||||
if (hidden) {
|
||||
widget.options.onHide?.(widget);
|
||||
return;
|
||||
}
|
||||
|
||||
const margin = 10;
|
||||
const elRect = ctx.canvas.getBoundingClientRect();
|
||||
const transform = new DOMMatrix()
|
||||
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||
.multiplySelf(ctx.getTransform())
|
||||
.translateSelf(margin, margin + y);
|
||||
|
||||
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d);
|
||||
|
||||
Object.assign(element.style, {
|
||||
transformOrigin: "0 0",
|
||||
transform: scale,
|
||||
left: `${transform.a + transform.e}px`,
|
||||
top: `${transform.d + transform.f}px`,
|
||||
width: `${widgetWidth - margin * 2}px`,
|
||||
height: `${(widget.computedHeight ?? 50) - margin * 2}px`,
|
||||
position: "absolute",
|
||||
zIndex: app.graph._nodes.indexOf(node),
|
||||
});
|
||||
|
||||
if (enableDomClipping) {
|
||||
element.style.clipPath = getClipPath(node, element, elRect);
|
||||
element.style.willChange = "clip-path";
|
||||
}
|
||||
|
||||
this.options.onDraw?.(widget);
|
||||
},
|
||||
element,
|
||||
options,
|
||||
onRemove() {
|
||||
if (mouseDownHandler) {
|
||||
document.removeEventListener("mousedown", mouseDownHandler);
|
||||
}
|
||||
element.remove();
|
||||
},
|
||||
};
|
||||
|
||||
for (const evt of options.selectOn) {
|
||||
element.addEventListener(evt, () => {
|
||||
app.canvas.selectNode(this);
|
||||
app.canvas.bringToFront(this);
|
||||
});
|
||||
}
|
||||
|
||||
this.addCustomWidget(widget);
|
||||
elementWidgets.add(this);
|
||||
|
||||
const collapse = this.collapse;
|
||||
this.collapse = function() {
|
||||
collapse.apply(this, arguments);
|
||||
if(this.flags?.collapsed) {
|
||||
element.hidden = true;
|
||||
element.style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
const onRemoved = this.onRemoved;
|
||||
this.onRemoved = function () {
|
||||
element.remove();
|
||||
elementWidgets.delete(this);
|
||||
onRemoved?.apply(this, arguments);
|
||||
};
|
||||
|
||||
if (!this[SIZE]) {
|
||||
this[SIZE] = true;
|
||||
const onResize = this.onResize;
|
||||
this.onResize = function (size) {
|
||||
options.beforeResize?.call(widget, this);
|
||||
computeSize.call(this, size);
|
||||
onResize?.apply(this, arguments);
|
||||
options.afterResize?.call(widget, this);
|
||||
};
|
||||
}
|
||||
|
||||
return widget;
|
||||
};
|
@ -24,7 +24,7 @@ export function getPngMetadata(file) {
|
||||
const length = dataView.getUint32(offset);
|
||||
// Get the chunk type
|
||||
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
|
||||
if (type === "tEXt") {
|
||||
if (type === "tEXt" || type == "comf") {
|
||||
// Get the keyword
|
||||
let keyword_end = offset + 8;
|
||||
while (pngData[keyword_end] !== 0) {
|
||||
@ -50,7 +50,6 @@ export function getPngMetadata(file) {
|
||||
function parseExifData(exifData) {
|
||||
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
|
||||
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
|
||||
console.log(exifData);
|
||||
|
||||
// Function to read 16-bit and 32-bit integers from binary data
|
||||
function readInt(offset, isLittleEndian, length) {
|
||||
@ -126,6 +125,9 @@ export function getWebpMetadata(file) {
|
||||
const chunk_length = dataView.getUint32(offset + 4, true);
|
||||
const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
|
||||
if (chunk_type === "EXIF") {
|
||||
if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") {
|
||||
offset += 6;
|
||||
}
|
||||
let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
|
||||
for (var key in data) {
|
||||
var value = data[key];
|
||||
|
@ -462,8 +462,8 @@ class ComfyList {
|
||||
return $el("div", {textContent: item.prompt[0] + ": "}, [
|
||||
$el("button", {
|
||||
textContent: "Load",
|
||||
onclick: () => {
|
||||
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
|
||||
onclick: async () => {
|
||||
await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
|
||||
if (item.outputs) {
|
||||
app.nodeOutputs = item.outputs;
|
||||
}
|
||||
@ -599,7 +599,7 @@ export class ComfyUI {
|
||||
const fileInput = $el("input", {
|
||||
id: "comfy-file-input",
|
||||
type: "file",
|
||||
accept: ".json,image/png,.latent,.safetensors",
|
||||
accept: ".json,image/png,.latent,.safetensors,image/webp",
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
onchange: () => {
|
||||
@ -784,9 +784,9 @@ export class ComfyUI {
|
||||
}
|
||||
}),
|
||||
$el("button", {
|
||||
id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
|
||||
id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
|
||||
if (!confirmClear.value || confirm("Load default workflow?")) {
|
||||
app.loadGraphData()
|
||||
await app.loadGraphData()
|
||||
}
|
||||
}
|
||||
}),
|
||||
|
97
web/scripts/ui/imagePreview.js
Normal file
97
web/scripts/ui/imagePreview.js
Normal file
@ -0,0 +1,97 @@
|
||||
import { $el } from "../ui.js";
|
||||
|
||||
export function calculateImageGrid(imgs, dw, dh) {
|
||||
let best = 0;
|
||||
let w = imgs[0].naturalWidth;
|
||||
let h = imgs[0].naturalHeight;
|
||||
const numImages = imgs.length;
|
||||
|
||||
let cellWidth, cellHeight, cols, rows, shiftX;
|
||||
// compact style
|
||||
for (let c = 1; c <= numImages; c++) {
|
||||
const r = Math.ceil(numImages / c);
|
||||
const cW = dw / c;
|
||||
const cH = dh / r;
|
||||
const scaleX = cW / w;
|
||||
const scaleY = cH / h;
|
||||
|
||||
const scale = Math.min(scaleX, scaleY, 1);
|
||||
const imageW = w * scale;
|
||||
const imageH = h * scale;
|
||||
const area = imageW * imageH * numImages;
|
||||
|
||||
if (area > best) {
|
||||
best = area;
|
||||
cellWidth = imageW;
|
||||
cellHeight = imageH;
|
||||
cols = c;
|
||||
rows = r;
|
||||
shiftX = c * ((cW - imageW) / 2);
|
||||
}
|
||||
}
|
||||
|
||||
return { cellWidth, cellHeight, cols, rows, shiftX };
|
||||
}
|
||||
|
||||
export function createImageHost(node) {
|
||||
const el = $el("div.comfy-img-preview");
|
||||
let currentImgs;
|
||||
let first = true;
|
||||
|
||||
function updateSize() {
|
||||
let w = null;
|
||||
let h = null;
|
||||
|
||||
if (currentImgs) {
|
||||
let elH = el.clientHeight;
|
||||
if (first) {
|
||||
first = false;
|
||||
// On first run, if we are small then grow a bit
|
||||
if (elH < 190) {
|
||||
elH = 190;
|
||||
}
|
||||
el.style.setProperty("--comfy-widget-min-height", elH);
|
||||
} else {
|
||||
el.style.setProperty("--comfy-widget-min-height", null);
|
||||
}
|
||||
|
||||
const nw = node.size[0];
|
||||
({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH));
|
||||
w += "px";
|
||||
h += "px";
|
||||
|
||||
el.style.setProperty("--comfy-img-preview-width", w);
|
||||
el.style.setProperty("--comfy-img-preview-height", h);
|
||||
}
|
||||
}
|
||||
return {
|
||||
el,
|
||||
updateImages(imgs) {
|
||||
if (imgs !== currentImgs) {
|
||||
if (currentImgs == null) {
|
||||
requestAnimationFrame(() => {
|
||||
updateSize();
|
||||
});
|
||||
}
|
||||
el.replaceChildren(...imgs);
|
||||
currentImgs = imgs;
|
||||
node.onResize(node.size);
|
||||
node.graph.setDirtyCanvas(true, true);
|
||||
}
|
||||
},
|
||||
getHeight() {
|
||||
updateSize();
|
||||
},
|
||||
onDraw() {
|
||||
// Element from point uses a hittest find elements so we need to toggle pointer events
|
||||
el.style.pointerEvents = "all";
|
||||
const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]);
|
||||
el.style.pointerEvents = "none";
|
||||
|
||||
if(!over) return;
|
||||
// Set the overIndex so Open Image etc work
|
||||
const idx = currentImgs.indexOf(over);
|
||||
node.overIndex = idx;
|
||||
},
|
||||
};
|
||||
}
|
@ -1,4 +1,5 @@
|
||||
import { api } from "./api.js"
|
||||
import "./domWidget.js";
|
||||
|
||||
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
|
||||
let defaultVal = inputData[1]["default"];
|
||||
@ -22,18 +23,89 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
|
||||
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
|
||||
}
|
||||
|
||||
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) {
|
||||
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
|
||||
values: ["fixed", "increment", "decrement", "randomize"],
|
||||
serialize: false, // Don't include this in prompt.
|
||||
});
|
||||
valueControl.afterQueued = () => {
|
||||
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) {
|
||||
let name = inputData[1]?.control_after_generate;
|
||||
if(typeof name !== "string") {
|
||||
name = widgetName;
|
||||
}
|
||||
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, {
|
||||
addFilterList: false,
|
||||
controlAfterGenerateName: name
|
||||
}, inputData);
|
||||
return widgets[0];
|
||||
}
|
||||
|
||||
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", options, inputData) {
|
||||
if (!defaultValue) defaultValue = "randomize";
|
||||
if (!options) options = {};
|
||||
|
||||
const getName = (defaultName, optionName) => {
|
||||
let name = defaultName;
|
||||
if (options[optionName]) {
|
||||
name = options[optionName];
|
||||
} else if (typeof inputData?.[1]?.[defaultName] === "string") {
|
||||
name = inputData?.[1]?.[defaultName];
|
||||
} else if (inputData?.[1]?.control_prefix) {
|
||||
name = inputData?.[1]?.control_prefix + " " + name
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
const widgets = [];
|
||||
const valueControl = node.addWidget(
|
||||
"combo",
|
||||
getName("control_after_generate", "controlAfterGenerateName"),
|
||||
defaultValue,
|
||||
function () {},
|
||||
{
|
||||
values: ["fixed", "increment", "decrement", "randomize"],
|
||||
serialize: false, // Don't include this in prompt.
|
||||
}
|
||||
);
|
||||
widgets.push(valueControl);
|
||||
|
||||
const isCombo = targetWidget.type === "combo";
|
||||
let comboFilter;
|
||||
if (isCombo && options.addFilterList !== false) {
|
||||
comboFilter = node.addWidget(
|
||||
"string",
|
||||
getName("control_filter_list", "controlFilterListName"),
|
||||
"",
|
||||
function () {},
|
||||
{
|
||||
serialize: false, // Don't include this in prompt.
|
||||
}
|
||||
);
|
||||
widgets.push(comboFilter);
|
||||
}
|
||||
|
||||
valueControl.afterQueued = () => {
|
||||
var v = valueControl.value;
|
||||
|
||||
if (targetWidget.type == "combo" && v !== "fixed") {
|
||||
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
|
||||
let current_length = targetWidget.options.values.length;
|
||||
if (isCombo && v !== "fixed") {
|
||||
let values = targetWidget.options.values;
|
||||
const filter = comboFilter?.value;
|
||||
if (filter) {
|
||||
let check;
|
||||
if (filter.startsWith("/") && filter.endsWith("/")) {
|
||||
try {
|
||||
const regex = new RegExp(filter.substring(1, filter.length - 1));
|
||||
check = (item) => regex.test(item);
|
||||
} catch (error) {
|
||||
console.error("Error constructing RegExp filter for node " + node.id, filter, error);
|
||||
}
|
||||
}
|
||||
if (!check) {
|
||||
const lower = filter.toLocaleLowerCase();
|
||||
check = (item) => item.toLocaleLowerCase().includes(lower);
|
||||
}
|
||||
values = values.filter(item => check(item));
|
||||
if (!values.length && targetWidget.options.values.length) {
|
||||
console.warn("Filter for node " + node.id + " has filtered out all items", filter);
|
||||
}
|
||||
}
|
||||
let current_index = values.indexOf(targetWidget.value);
|
||||
let current_length = values.length;
|
||||
|
||||
switch (v) {
|
||||
case "increment":
|
||||
@ -50,11 +122,12 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
||||
current_index = Math.max(0, current_index);
|
||||
current_index = Math.min(current_length - 1, current_index);
|
||||
if (current_index >= 0) {
|
||||
let value = targetWidget.options.values[current_index];
|
||||
let value = values[current_index];
|
||||
targetWidget.value = value;
|
||||
targetWidget.callback(value);
|
||||
}
|
||||
} else { //number
|
||||
} else {
|
||||
//number
|
||||
let min = targetWidget.options.min;
|
||||
let max = targetWidget.options.max;
|
||||
// limit to something that javascript can handle
|
||||
@ -77,186 +150,68 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
|
||||
default:
|
||||
break;
|
||||
}
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min)
|
||||
targetWidget.value = min;
|
||||
/*check if values are over or under their respective
|
||||
* ranges and set them to min or max.*/
|
||||
if (targetWidget.value < min) targetWidget.value = min;
|
||||
|
||||
if (targetWidget.value > max)
|
||||
targetWidget.value = max;
|
||||
targetWidget.callback(targetWidget.value);
|
||||
}
|
||||
}
|
||||
return valueControl;
|
||||
};
|
||||
return widgets;
|
||||
};
|
||||
|
||||
function seedWidget(node, inputName, inputData, app) {
|
||||
const seed = ComfyWidgets.INT(node, inputName, inputData, app);
|
||||
const seedControl = addValueControlWidget(node, seed.widget, "randomize");
|
||||
function seedWidget(node, inputName, inputData, app, widgetName) {
|
||||
const seed = createIntWidget(node, inputName, inputData, app, true);
|
||||
const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName, inputData);
|
||||
|
||||
seed.widget.linkedWidgets = [seedControl];
|
||||
return seed;
|
||||
}
|
||||
|
||||
const MultilineSymbol = Symbol();
|
||||
const MultilineResizeSymbol = Symbol();
|
||||
function createIntWidget(node, inputName, inputData, app, isSeedInput) {
|
||||
const control = inputData[1]?.control_after_generate;
|
||||
if (!isSeedInput && control) {
|
||||
return seedWidget(node, inputName, inputData, app, typeof control === "string" ? control : undefined);
|
||||
}
|
||||
|
||||
let widgetType = isSlider(inputData[1]["display"], app);
|
||||
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
|
||||
Object.assign(config, { precision: 0 });
|
||||
return {
|
||||
widget: node.addWidget(
|
||||
widgetType,
|
||||
inputName,
|
||||
val,
|
||||
function (v) {
|
||||
const s = this.options.step / 10;
|
||||
this.value = Math.round(v / s) * s;
|
||||
},
|
||||
config
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
function addMultilineWidget(node, name, opts, app) {
|
||||
const MIN_SIZE = 50;
|
||||
const inputEl = document.createElement("textarea");
|
||||
inputEl.className = "comfy-multiline-input";
|
||||
inputEl.value = opts.defaultVal;
|
||||
inputEl.placeholder = opts.placeholder || name;
|
||||
|
||||
function computeSize(size) {
|
||||
if (node.widgets[0].last_y == null) return;
|
||||
|
||||
let y = node.widgets[0].last_y;
|
||||
let freeSpace = size[1] - y;
|
||||
|
||||
// Compute the height of all non customtext widgets
|
||||
let widgetHeight = 0;
|
||||
const multi = [];
|
||||
for (let i = 0; i < node.widgets.length; i++) {
|
||||
const w = node.widgets[i];
|
||||
if (w.type === "customtext") {
|
||||
multi.push(w);
|
||||
} else {
|
||||
if (w.computeSize) {
|
||||
widgetHeight += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See how large each text input can be
|
||||
freeSpace -= widgetHeight;
|
||||
freeSpace /= multi.length + (!!node.imgs?.length);
|
||||
|
||||
if (freeSpace < MIN_SIZE) {
|
||||
// There isnt enough space for all the widgets, increase the size of the node
|
||||
freeSpace = MIN_SIZE;
|
||||
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
|
||||
node.graph.setDirtyCanvas(true);
|
||||
}
|
||||
|
||||
// Position each of the widgets
|
||||
for (const w of node.widgets) {
|
||||
w.y = y;
|
||||
if (w.type === "customtext") {
|
||||
y += freeSpace;
|
||||
w.computedHeight = freeSpace - multi.length*4;
|
||||
} else if (w.computeSize) {
|
||||
y += w.computeSize()[1] + 4;
|
||||
} else {
|
||||
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
|
||||
}
|
||||
}
|
||||
|
||||
node.inputHeight = freeSpace;
|
||||
}
|
||||
|
||||
const widget = {
|
||||
type: "customtext",
|
||||
name,
|
||||
get value() {
|
||||
return this.inputEl.value;
|
||||
const widget = node.addDOMWidget(name, "customtext", inputEl, {
|
||||
getValue() {
|
||||
return inputEl.value;
|
||||
},
|
||||
set value(x) {
|
||||
this.inputEl.value = x;
|
||||
setValue(v) {
|
||||
inputEl.value = v;
|
||||
},
|
||||
draw: function (ctx, _, widgetWidth, y, widgetHeight) {
|
||||
if (!this.parent.inputHeight) {
|
||||
// If we are initially offscreen when created we wont have received a resize event
|
||||
// Calculate it here instead
|
||||
computeSize(node.size);
|
||||
}
|
||||
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
|
||||
const margin = 10;
|
||||
const elRect = ctx.canvas.getBoundingClientRect();
|
||||
const transform = new DOMMatrix()
|
||||
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
|
||||
.multiplySelf(ctx.getTransform())
|
||||
.translateSelf(margin, margin + y);
|
||||
|
||||
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d)
|
||||
Object.assign(this.inputEl.style, {
|
||||
transformOrigin: "0 0",
|
||||
transform: scale,
|
||||
left: `${transform.a + transform.e}px`,
|
||||
top: `${transform.d + transform.f}px`,
|
||||
width: `${widgetWidth - (margin * 2)}px`,
|
||||
height: `${this.parent.inputHeight - (margin * 2)}px`,
|
||||
position: "absolute",
|
||||
background: (!node.color)?'':node.color,
|
||||
color: (!node.color)?'':'white',
|
||||
zIndex: app.graph._nodes.indexOf(node),
|
||||
});
|
||||
this.inputEl.hidden = !visible;
|
||||
},
|
||||
};
|
||||
widget.inputEl = document.createElement("textarea");
|
||||
widget.inputEl.className = "comfy-multiline-input";
|
||||
widget.inputEl.value = opts.defaultVal;
|
||||
widget.inputEl.placeholder = opts.placeholder || "";
|
||||
document.addEventListener("mousedown", function (event) {
|
||||
if (!widget.inputEl.contains(event.target)) {
|
||||
widget.inputEl.blur();
|
||||
}
|
||||
});
|
||||
widget.parent = node;
|
||||
document.body.appendChild(widget.inputEl);
|
||||
widget.inputEl = inputEl;
|
||||
|
||||
node.addCustomWidget(widget);
|
||||
|
||||
app.canvas.onDrawBackground = function () {
|
||||
// Draw node isnt fired once the node is off the screen
|
||||
// if it goes off screen quickly, the input may not be removed
|
||||
// this shifts it off screen so it can be moved back if the node is visible.
|
||||
for (let n in app.graph._nodes) {
|
||||
n = graph._nodes[n];
|
||||
for (let w in n.widgets) {
|
||||
let wid = n.widgets[w];
|
||||
if (Object.hasOwn(wid, "inputEl")) {
|
||||
wid.inputEl.style.left = -8000 + "px";
|
||||
wid.inputEl.style.position = "absolute";
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
node.onRemoved = function () {
|
||||
// When removing this node we need to remove the input from the DOM
|
||||
for (let y in this.widgets) {
|
||||
if (this.widgets[y].inputEl) {
|
||||
this.widgets[y].inputEl.remove();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
widget.onRemove = () => {
|
||||
widget.inputEl?.remove();
|
||||
|
||||
// Restore original size handler if we are the last
|
||||
if (!--node[MultilineSymbol]) {
|
||||
node.onResize = node[MultilineResizeSymbol];
|
||||
delete node[MultilineSymbol];
|
||||
delete node[MultilineResizeSymbol];
|
||||
}
|
||||
};
|
||||
|
||||
if (node[MultilineSymbol]) {
|
||||
node[MultilineSymbol]++;
|
||||
} else {
|
||||
node[MultilineSymbol] = 1;
|
||||
const onResize = (node[MultilineResizeSymbol] = node.onResize);
|
||||
|
||||
node.onResize = function (size) {
|
||||
computeSize(size);
|
||||
|
||||
// Call original resizer handler
|
||||
if (onResize) {
|
||||
onResize.apply(this, arguments);
|
||||
}
|
||||
};
|
||||
}
|
||||
inputEl.addEventListener("input", () => {
|
||||
widget.callback?.(widget.value);
|
||||
});
|
||||
|
||||
return { minWidth: 400, minHeight: 200, widget };
|
||||
}
|
||||
@ -288,31 +243,26 @@ export const ComfyWidgets = {
|
||||
}, config) };
|
||||
},
|
||||
INT(node, inputName, inputData, app) {
|
||||
let widgetType = isSlider(inputData[1]["display"], app);
|
||||
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
|
||||
Object.assign(config, { precision: 0 });
|
||||
return {
|
||||
widget: node.addWidget(
|
||||
widgetType,
|
||||
inputName,
|
||||
val,
|
||||
function (v) {
|
||||
const s = this.options.step / 10;
|
||||
this.value = Math.round(v / s) * s;
|
||||
},
|
||||
config
|
||||
),
|
||||
};
|
||||
return createIntWidget(node, inputName, inputData, app);
|
||||
},
|
||||
BOOLEAN(node, inputName, inputData) {
|
||||
let defaultVal = inputData[1]["default"];
|
||||
let defaultVal = false;
|
||||
let options = {};
|
||||
if (inputData[1]) {
|
||||
if (inputData[1].default)
|
||||
defaultVal = inputData[1].default;
|
||||
if (inputData[1].label_on)
|
||||
options["on"] = inputData[1].label_on;
|
||||
if (inputData[1].label_off)
|
||||
options["off"] = inputData[1].label_off;
|
||||
}
|
||||
return {
|
||||
widget: node.addWidget(
|
||||
"toggle",
|
||||
inputName,
|
||||
defaultVal,
|
||||
() => {},
|
||||
{"on": inputData[1].label_on, "off": inputData[1].label_off}
|
||||
options,
|
||||
)
|
||||
};
|
||||
},
|
||||
@ -338,10 +288,14 @@ export const ComfyWidgets = {
|
||||
if (inputData[1] && inputData[1].default) {
|
||||
defaultValue = inputData[1].default;
|
||||
}
|
||||
return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
|
||||
const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
|
||||
if (inputData[1]?.control_after_generate) {
|
||||
res.widget.linkedWidgets = addValueControlWidgets(node, res.widget, undefined, undefined, inputData);
|
||||
}
|
||||
return res;
|
||||
},
|
||||
IMAGEUPLOAD(node, inputName, inputData, app) {
|
||||
const imageWidget = node.widgets.find((w) => w.name === "image");
|
||||
const imageWidget = node.widgets.find((w) => w.name === (inputData[1]?.widget ?? "image"));
|
||||
let uploadWidget;
|
||||
|
||||
function showImage(name) {
|
||||
@ -455,9 +409,10 @@ export const ComfyWidgets = {
|
||||
document.body.append(fileInput);
|
||||
|
||||
// Create the button widget for selecting the files
|
||||
uploadWidget = node.addWidget("button", "choose file to upload", "image", () => {
|
||||
uploadWidget = node.addWidget("button", inputName, "image", () => {
|
||||
fileInput.click();
|
||||
});
|
||||
uploadWidget.label = "choose file to upload";
|
||||
uploadWidget.serialize = false;
|
||||
|
||||
// Add handler to check if an image is being dragged over our node
|
||||
|
@ -409,6 +409,21 @@ dialog::backdrop {
|
||||
width: calc(100% - 10px);
|
||||
}
|
||||
|
||||
.comfy-img-preview {
|
||||
pointer-events: none;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-content: flex-start;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.comfy-img-preview img {
|
||||
object-fit: contain;
|
||||
width: var(--comfy-img-preview-width);
|
||||
height: var(--comfy-img-preview-height);
|
||||
}
|
||||
|
||||
/* Search box */
|
||||
|
||||
.litegraph.litesearchbox {
|
||||
|
Loading…
Reference in New Issue
Block a user