mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Merge remote-tracking branch 'origin/master' into custom_routes
This commit is contained in:
commit
38833ceb62
@ -78,7 +78,7 @@ class DDIMSampler(object):
|
|||||||
dynamic_threshold=None,
|
dynamic_threshold=None,
|
||||||
ucg_schedule=None,
|
ucg_schedule=None,
|
||||||
denoise_function=None,
|
denoise_function=None,
|
||||||
cond_concat=None,
|
extra_args=None,
|
||||||
to_zero=True,
|
to_zero=True,
|
||||||
end_step=None,
|
end_step=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@ -101,7 +101,7 @@ class DDIMSampler(object):
|
|||||||
dynamic_threshold=dynamic_threshold,
|
dynamic_threshold=dynamic_threshold,
|
||||||
ucg_schedule=ucg_schedule,
|
ucg_schedule=ucg_schedule,
|
||||||
denoise_function=denoise_function,
|
denoise_function=denoise_function,
|
||||||
cond_concat=cond_concat,
|
extra_args=extra_args,
|
||||||
to_zero=to_zero,
|
to_zero=to_zero,
|
||||||
end_step=end_step
|
end_step=end_step
|
||||||
)
|
)
|
||||||
@ -174,7 +174,7 @@ class DDIMSampler(object):
|
|||||||
dynamic_threshold=dynamic_threshold,
|
dynamic_threshold=dynamic_threshold,
|
||||||
ucg_schedule=ucg_schedule,
|
ucg_schedule=ucg_schedule,
|
||||||
denoise_function=None,
|
denoise_function=None,
|
||||||
cond_concat=None
|
extra_args=None
|
||||||
)
|
)
|
||||||
return samples, intermediates
|
return samples, intermediates
|
||||||
|
|
||||||
@ -185,7 +185,7 @@ class DDIMSampler(object):
|
|||||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||||
ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
|
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
@ -225,7 +225,7 @@ class DDIMSampler(object):
|
|||||||
corrector_kwargs=corrector_kwargs,
|
corrector_kwargs=corrector_kwargs,
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
|
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
|
||||||
img, pred_x0 = outs
|
img, pred_x0 = outs
|
||||||
if callback: callback(i)
|
if callback: callback(i)
|
||||||
if img_callback: img_callback(pred_x0, i)
|
if img_callback: img_callback(pred_x0, i)
|
||||||
@ -249,11 +249,11 @@ class DDIMSampler(object):
|
|||||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||||
dynamic_threshold=None, denoise_function=None, cond_concat=None):
|
dynamic_threshold=None, denoise_function=None, extra_args=None):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
if denoise_function is not None:
|
if denoise_function is not None:
|
||||||
model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
|
model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
|
||||||
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
model_output = self.model.apply_model(x, t, c)
|
model_output = self.model.apply_model(x, t, c)
|
||||||
else:
|
else:
|
||||||
|
@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
|
|||||||
self.conditioning_key = conditioning_key
|
self.conditioning_key = conditioning_key
|
||||||
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
|
||||||
|
|
||||||
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
|
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
|
||||||
if self.conditioning_key is None:
|
if self.conditioning_key is None:
|
||||||
out = self.diffusion_model(x, t, control=control)
|
out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'concat':
|
elif self.conditioning_key == 'concat':
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
out = self.diffusion_model(xc, t, control=control)
|
out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'crossattn':
|
elif self.conditioning_key == 'crossattn':
|
||||||
if not self.sequential_cross_attn:
|
if not self.sequential_cross_attn:
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
|
|||||||
# TorchScript changes names of the arguments
|
# TorchScript changes names of the arguments
|
||||||
# with argument cc defined as context=cc scripted model will produce
|
# with argument cc defined as context=cc scripted model will produce
|
||||||
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||||
out = self.scripted_diffusion_model(x, t, cc, control=control)
|
out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
out = self.diffusion_model(x, t, context=cc, control=control)
|
out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'hybrid':
|
elif self.conditioning_key == 'hybrid':
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(xc, t, context=cc, control=control)
|
out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'hybrid-adm':
|
elif self.conditioning_key == 'hybrid-adm':
|
||||||
assert c_adm is not None
|
assert c_adm is not None
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
|
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'crossattn-adm':
|
elif self.conditioning_key == 'crossattn-adm':
|
||||||
assert c_adm is not None
|
assert c_adm is not None
|
||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
|
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
|
||||||
elif self.conditioning_key == 'adm':
|
elif self.conditioning_key == 'adm':
|
||||||
cc = c_crossattn[0]
|
cc = c_crossattn[0]
|
||||||
out = self.diffusion_model(x, t, y=cc, control=control)
|
out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
|||||||
|
|
||||||
import model_management
|
import model_management
|
||||||
|
|
||||||
|
from . import tomesd
|
||||||
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
import xformers
|
import xformers
|
||||||
@ -504,12 +505,22 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
self.norm3 = nn.LayerNorm(dim)
|
self.norm3 = nn.LayerNorm(dim)
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
|
|
||||||
def forward(self, x, context=None):
|
def forward(self, x, context=None, transformer_options={}):
|
||||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
def _forward(self, x, context=None):
|
def _forward(self, x, context=None, transformer_options={}):
|
||||||
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
n = self.norm1(x)
|
||||||
x = self.attn2(self.norm2(x), context=context) + x
|
if "tomesd" in transformer_options:
|
||||||
|
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||||
|
n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
|
||||||
|
else:
|
||||||
|
n = self.attn1(n, context=context if self.disable_self_attn else None)
|
||||||
|
|
||||||
|
x += n
|
||||||
|
n = self.norm2(x)
|
||||||
|
n = self.attn2(n, context=context)
|
||||||
|
|
||||||
|
x += n
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -557,7 +568,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||||
self.use_linear = use_linear
|
self.use_linear = use_linear
|
||||||
|
|
||||||
def forward(self, x, context=None):
|
def forward(self, x, context=None, transformer_options={}):
|
||||||
# note: if no context is given, cross-attention defaults to self-attention
|
# note: if no context is given, cross-attention defaults to self-attention
|
||||||
if not isinstance(context, list):
|
if not isinstance(context, list):
|
||||||
context = [context]
|
context = [context]
|
||||||
@ -570,7 +581,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
x = block(x, context=context[i])
|
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||||
if self.use_linear:
|
if self.use_linear:
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||||
|
@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
|||||||
support it as an extra input.
|
support it as an extra input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, x, emb, context=None):
|
def forward(self, x, emb, context=None, transformer_options={}):
|
||||||
for layer in self:
|
for layer in self:
|
||||||
if isinstance(layer, TimestepBlock):
|
if isinstance(layer, TimestepBlock):
|
||||||
x = layer(x, emb)
|
x = layer(x, emb)
|
||||||
elif isinstance(layer, SpatialTransformer):
|
elif isinstance(layer, SpatialTransformer):
|
||||||
x = layer(x, context)
|
x = layer(x, context, transformer_options)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
return x
|
||||||
@ -753,7 +753,7 @@ class UNetModel(nn.Module):
|
|||||||
self.middle_block.apply(convert_module_to_f32)
|
self.middle_block.apply(convert_module_to_f32)
|
||||||
self.output_blocks.apply(convert_module_to_f32)
|
self.output_blocks.apply(convert_module_to_f32)
|
||||||
|
|
||||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
|
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||||
"""
|
"""
|
||||||
Apply the model to an input batch.
|
Apply the model to an input batch.
|
||||||
:param x: an [N x C x ...] Tensor of inputs.
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
@ -762,6 +762,7 @@ class UNetModel(nn.Module):
|
|||||||
:param y: an [N] Tensor of labels, if class-conditional.
|
:param y: an [N] Tensor of labels, if class-conditional.
|
||||||
:return: an [N x C x ...] Tensor of outputs.
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
"""
|
"""
|
||||||
|
transformer_options["original_shape"] = list(x.shape)
|
||||||
assert (y is not None) == (
|
assert (y is not None) == (
|
||||||
self.num_classes is not None
|
self.num_classes is not None
|
||||||
), "must specify y if and only if the model is class-conditional"
|
), "must specify y if and only if the model is class-conditional"
|
||||||
@ -775,13 +776,13 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
h = x.type(self.dtype)
|
h = x.type(self.dtype)
|
||||||
for id, module in enumerate(self.input_blocks):
|
for id, module in enumerate(self.input_blocks):
|
||||||
h = module(h, emb, context)
|
h = module(h, emb, context, transformer_options)
|
||||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||||
ctrl = control['input'].pop()
|
ctrl = control['input'].pop()
|
||||||
if ctrl is not None:
|
if ctrl is not None:
|
||||||
h += ctrl
|
h += ctrl
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
h = self.middle_block(h, emb, context)
|
h = self.middle_block(h, emb, context, transformer_options)
|
||||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||||
h += control['middle'].pop()
|
h += control['middle'].pop()
|
||||||
|
|
||||||
@ -793,7 +794,7 @@ class UNetModel(nn.Module):
|
|||||||
hsp += ctrl
|
hsp += ctrl
|
||||||
h = th.cat([h, hsp], dim=1)
|
h = th.cat([h, hsp], dim=1)
|
||||||
del hsp
|
del hsp
|
||||||
h = module(h, emb, context)
|
h = module(h, emb, context, transformer_options)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
return self.id_predictor(h)
|
return self.id_predictor(h)
|
||||||
|
117
comfy/ldm/modules/tomesd.py
Normal file
117
comfy/ldm/modules/tomesd.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing import Tuple, Callable
|
||||||
|
import math
|
||||||
|
|
||||||
|
def do_nothing(x: torch.Tensor, mode:str=None):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def bipartite_soft_matching_random2d(metric: torch.Tensor,
|
||||||
|
w: int, h: int, sx: int, sy: int, r: int,
|
||||||
|
no_rand: bool = False) -> Tuple[Callable, Callable]:
|
||||||
|
"""
|
||||||
|
Partitions the tokens into src and dst and merges r tokens from src to dst.
|
||||||
|
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- metric [B, N, C]: metric to use for similarity
|
||||||
|
- w: image width in tokens
|
||||||
|
- h: image height in tokens
|
||||||
|
- sx: stride in the x dimension for dst, must divide w
|
||||||
|
- sy: stride in the y dimension for dst, must divide h
|
||||||
|
- r: number of tokens to remove (by merging)
|
||||||
|
- no_rand: if true, disable randomness (use top left corner only)
|
||||||
|
"""
|
||||||
|
B, N, _ = metric.shape
|
||||||
|
|
||||||
|
if r <= 0:
|
||||||
|
return do_nothing, do_nothing
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
hsy, wsx = h // sy, w // sx
|
||||||
|
|
||||||
|
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
|
||||||
|
idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)
|
||||||
|
|
||||||
|
if no_rand:
|
||||||
|
rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
|
||||||
|
else:
|
||||||
|
rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)
|
||||||
|
|
||||||
|
idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
|
||||||
|
idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
|
||||||
|
rand_idx = idx_buffer.argsort(dim=1)
|
||||||
|
|
||||||
|
num_dst = int((1 / (sx*sy)) * N)
|
||||||
|
a_idx = rand_idx[:, num_dst:, :] # src
|
||||||
|
b_idx = rand_idx[:, :num_dst, :] # dst
|
||||||
|
|
||||||
|
def split(x):
|
||||||
|
C = x.shape[-1]
|
||||||
|
src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
|
||||||
|
dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
|
||||||
|
return src, dst
|
||||||
|
|
||||||
|
metric = metric / metric.norm(dim=-1, keepdim=True)
|
||||||
|
a, b = split(metric)
|
||||||
|
scores = a @ b.transpose(-1, -2)
|
||||||
|
|
||||||
|
# Can't reduce more than the # tokens in src
|
||||||
|
r = min(a.shape[1], r)
|
||||||
|
|
||||||
|
node_max, node_idx = scores.max(dim=-1)
|
||||||
|
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
|
||||||
|
|
||||||
|
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
|
||||||
|
src_idx = edge_idx[..., :r, :] # Merged Tokens
|
||||||
|
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
|
||||||
|
|
||||||
|
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
||||||
|
src, dst = split(x)
|
||||||
|
n, t1, c = src.shape
|
||||||
|
|
||||||
|
unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
|
||||||
|
src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
|
||||||
|
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
|
||||||
|
|
||||||
|
return torch.cat([unm, dst], dim=1)
|
||||||
|
|
||||||
|
def unmerge(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
unm_len = unm_idx.shape[1]
|
||||||
|
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
|
||||||
|
_, _, c = unm.shape
|
||||||
|
|
||||||
|
src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))
|
||||||
|
|
||||||
|
# Combine back to the original shape
|
||||||
|
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
|
||||||
|
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
|
||||||
|
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
|
||||||
|
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
return merge, unmerge
|
||||||
|
|
||||||
|
|
||||||
|
def get_functions(x, ratio, original_shape):
|
||||||
|
b, c, original_h, original_w = original_shape
|
||||||
|
original_tokens = original_h * original_w
|
||||||
|
downsample = int(math.sqrt(original_tokens // x.shape[1]))
|
||||||
|
stride_x = 2
|
||||||
|
stride_y = 2
|
||||||
|
max_downsample = 1
|
||||||
|
|
||||||
|
if downsample <= max_downsample:
|
||||||
|
w = original_w // downsample
|
||||||
|
h = original_h // downsample
|
||||||
|
r = int(x.shape[1] * ratio)
|
||||||
|
no_rand = False
|
||||||
|
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
|
||||||
|
return m, u
|
||||||
|
|
||||||
|
nothing = lambda y: y
|
||||||
|
return nothing, nothing
|
@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns predicted noise
|
#Returns predicted noise
|
||||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
|
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
|
||||||
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
out['c_concat'] = [torch.cat(c_concat)]
|
out['c_concat'] = [torch.cat(c_concat)]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in)/100000.0
|
out_count = torch.ones_like(x_in)/100000.0
|
||||||
|
|
||||||
@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if control is not None:
|
if control is not None:
|
||||||
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
|
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
|
||||||
|
|
||||||
|
if 'transformer_options' in model_options:
|
||||||
|
c['transformer_options'] = model_options['transformer_options']
|
||||||
|
|
||||||
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
|
||||||
del input_x
|
del input_x
|
||||||
|
|
||||||
@ -192,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
|
|
||||||
max_total_area = model_management.maximum_batch_area()
|
max_total_area = model_management.maximum_batch_area()
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
||||||
return uncond + (cond - uncond) * cond_scale
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
|
|
||||||
@ -209,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.alphas_cumprod = model.alphas_cumprod
|
self.alphas_cumprod = model.alphas_cumprod
|
||||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
|
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
|
||||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
|
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@ -218,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
|||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
||||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
|
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out *= denoise_mask
|
out *= denoise_mask
|
||||||
|
|
||||||
@ -330,7 +333,7 @@ class KSampler:
|
|||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
|
||||||
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
|
||||||
|
|
||||||
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
|
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.model_denoise = CFGNoisePredictor(self.model)
|
self.model_denoise = CFGNoisePredictor(self.model)
|
||||||
if self.model.parameterization == "v":
|
if self.model.parameterization == "v":
|
||||||
@ -350,6 +353,7 @@ class KSampler:
|
|||||||
self.sigma_max=float(self.model_wrap.sigma_max)
|
self.sigma_max=float(self.model_wrap.sigma_max)
|
||||||
self.set_steps(steps, denoise)
|
self.set_steps(steps, denoise)
|
||||||
self.denoise = denoise
|
self.denoise = denoise
|
||||||
|
self.model_options = model_options
|
||||||
|
|
||||||
def _calculate_sigmas(self, steps):
|
def _calculate_sigmas(self, steps):
|
||||||
sigmas = None
|
sigmas = None
|
||||||
@ -418,7 +422,7 @@ class KSampler:
|
|||||||
else:
|
else:
|
||||||
precision_scope = contextlib.nullcontext
|
precision_scope = contextlib.nullcontext
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
|
||||||
|
|
||||||
cond_concat = None
|
cond_concat = None
|
||||||
if hasattr(self.model, 'concat_keys'):
|
if hasattr(self.model, 'concat_keys'):
|
||||||
@ -467,7 +471,7 @@ class KSampler:
|
|||||||
x_T=z_enc,
|
x_T=z_enc,
|
||||||
x0=latent_image,
|
x0=latent_image,
|
||||||
denoise_function=sampling_function,
|
denoise_function=sampling_function,
|
||||||
cond_concat=cond_concat,
|
extra_args=extra_args,
|
||||||
mask=noise_mask,
|
mask=noise_mask,
|
||||||
to_zero=sigmas[-1]==0,
|
to_zero=sigmas[-1]==0,
|
||||||
end_step=sigmas.shape[0] - 1)
|
end_step=sigmas.shape[0] - 1)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import copy
|
||||||
|
|
||||||
import sd1_clip
|
import sd1_clip
|
||||||
import sd2_clip
|
import sd2_clip
|
||||||
@ -274,12 +275,20 @@ class ModelPatcher:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.patches = []
|
self.patches = []
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
|
self.model_options = {"transformer_options":{}}
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model)
|
n = ModelPatcher(self.model)
|
||||||
n.patches = self.patches[:]
|
n.patches = self.patches[:]
|
||||||
|
n.model_options = copy.deepcopy(self.model_options)
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def set_model_tomesd(self, ratio):
|
||||||
|
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
|
||||||
|
|
||||||
|
def model_dtype(self):
|
||||||
|
return self.model.diffusion_model.dtype
|
||||||
|
|
||||||
def add_patches(self, patches, strength=1.0):
|
def add_patches(self, patches, strength=1.0):
|
||||||
p = {}
|
p = {}
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
|
19
nodes.py
19
nodes.py
@ -254,6 +254,22 @@ class LoraLoader:
|
|||||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
||||||
return (model_lora, clip_lora)
|
return (model_lora, clip_lora)
|
||||||
|
|
||||||
|
class TomePatchModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, ratio):
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_tomesd(ratio)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
class VAELoader:
|
class VAELoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||||||
model_management.load_controlnet_gpu(control_net_models)
|
model_management.load_controlnet_gpu(control_net_models)
|
||||||
|
|
||||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||||
else:
|
else:
|
||||||
#other samplers
|
#other samplers
|
||||||
pass
|
pass
|
||||||
@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CLIPVisionLoader": CLIPVisionLoader,
|
"CLIPVisionLoader": CLIPVisionLoader,
|
||||||
"VAEDecodeTiled": VAEDecodeTiled,
|
"VAEDecodeTiled": VAEDecodeTiled,
|
||||||
"VAEEncodeTiled": VAEEncodeTiled,
|
"VAEEncodeTiled": VAEEncodeTiled,
|
||||||
|
"TomePatchModel": TomePatchModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
def load_custom_node(module_path):
|
def load_custom_node(module_path):
|
||||||
|
@ -101,7 +101,7 @@ app.registerExtension({
|
|||||||
callback: () => convertToWidget(this, w),
|
callback: () => convertToWidget(this, w),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
const config = nodeData?.input?.required[w.name] || [w.type, w.options || {}];
|
const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}];
|
||||||
if (isConvertableWidget(w, config)) {
|
if (isConvertableWidget(w, config)) {
|
||||||
toInput.push({
|
toInput.push({
|
||||||
content: `Convert ${w.name} to input`,
|
content: `Convert ${w.name} to input`,
|
||||||
|
Loading…
Reference in New Issue
Block a user