Support SVD img2vid model.

commit 871cc20e13 (parent 022033a0e7)
@@ -54,6 +54,7 @@ class ControlNet(nn.Module):
         transformer_depth_output=None,
         device=None,
         operations=comfy.ops,
+        **kwargs,
     ):
         super().__init__()
         assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -5,8 +5,10 @@ import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
 from typing import Optional, Any
+from functools import partial

-from .diffusionmodules.util import checkpoint
+from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention

 from comfy import model_management
@@ -370,21 +372,45 @@ class CrossAttention(nn.Module):


 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
+                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
         super().__init__()
+
+        self.ff_in = ff_in or inner_dim is not None
+        if inner_dim is None:
+            inner_dim = dim
+
+        self.is_res = inner_dim == dim
+
+        if self.ff_in:
+            self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
+            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
+
         self.disable_self_attn = disable_self_attn
-        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+        self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                               context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
-        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
+
+        if disable_temporal_crossattention:
+            if switch_temporal_ca_to_sa:
+                raise ValueError
+            else:
+                self.attn2 = None
+        else:
+            context_dim_attn2 = None
+            if not switch_temporal_ca_to_sa:
+                context_dim_attn2 = context_dim
+
+            self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
+            self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+
+        self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
+        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

     def forward(self, x, context=None, transformer_options={}):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
@@ -418,6 +444,12 @@ class BasicTransformerBlock(nn.Module):
         else:
             transformer_patches_replace = {}

+        if self.ff_in:
+            x_skip = x
+            x = self.ff_in(self.norm_in(x))
+            if self.is_res:
+                x += x_skip
+
         n = self.norm1(x)
         if self.disable_self_attn:
             context_attn1 = context
@@ -465,31 +497,34 @@ class BasicTransformerBlock(nn.Module):
                 for p in patch:
                     x = p(x, extra_options)

-        n = self.norm2(x)
-        context_attn2 = context
-        value_attn2 = None
-        if "attn2_patch" in transformer_patches:
-            patch = transformer_patches["attn2_patch"]
-            value_attn2 = context_attn2
-            for p in patch:
-                n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
-
-        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
-        block_attn2 = transformer_block
-        if block_attn2 not in attn2_replace_patch:
-            block_attn2 = block
-
-        if block_attn2 in attn2_replace_patch:
-            if value_attn2 is None:
-                value_attn2 = context_attn2
-            n = self.attn2.to_q(n)
-            context_attn2 = self.attn2.to_k(context_attn2)
-            value_attn2 = self.attn2.to_v(value_attn2)
-            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
-            n = self.attn2.to_out(n)
-        else:
-            n = self.attn2(n, context=context_attn2, value=value_attn2)
+        if self.attn2 is not None:
+            n = self.norm2(x)
+            if self.switch_temporal_ca_to_sa:
+                context_attn2 = n
+            else:
+                context_attn2 = context
+            value_attn2 = None
+            if "attn2_patch" in transformer_patches:
+                patch = transformer_patches["attn2_patch"]
+                value_attn2 = context_attn2
+                for p in patch:
+                    n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
+
+            attn2_replace_patch = transformer_patches_replace.get("attn2", {})
+            block_attn2 = transformer_block
+            if block_attn2 not in attn2_replace_patch:
+                block_attn2 = block
+
+            if block_attn2 in attn2_replace_patch:
+                if value_attn2 is None:
+                    value_attn2 = context_attn2
+                n = self.attn2.to_q(n)
+                context_attn2 = self.attn2.to_k(context_attn2)
+                value_attn2 = self.attn2.to_v(value_attn2)
+                n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
+                n = self.attn2.to_out(n)
+            else:
+                n = self.attn2(n, context=context_attn2, value=value_attn2)

         if "attn2_output_patch" in transformer_patches:
             patch = transformer_patches["attn2_output_patch"]
@@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module):
                 n = p(n, extra_options)

         x += n
-        x = self.ff(self.norm3(x)) + x
+        if self.is_res:
+            x_skip = x
+        x = self.ff(self.norm3(x))
+        if self.is_res:
+            x += x_skip
+
         return x

@@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in

+
+class SpatialVideoTransformer(SpatialTransformer):
+    def __init__(
+        self,
+        in_channels,
+        n_heads,
+        d_head,
+        depth=1,
+        dropout=0.0,
+        use_linear=False,
+        context_dim=None,
+        use_spatial_context=False,
+        timesteps=None,
+        merge_strategy: str = "fixed",
+        merge_factor: float = 0.5,
+        time_context_dim=None,
+        ff_in=False,
+        checkpoint=False,
+        time_depth=1,
+        disable_self_attn=False,
+        disable_temporal_crossattention=False,
+        max_time_embed_period: int = 10000,
+        dtype=None, device=None, operations=comfy.ops
+    ):
+        super().__init__(
+            in_channels,
+            n_heads,
+            d_head,
+            depth=depth,
+            dropout=dropout,
+            use_checkpoint=checkpoint,
+            context_dim=context_dim,
+            use_linear=use_linear,
+            disable_self_attn=disable_self_attn,
+            dtype=dtype, device=device, operations=operations
+        )
+        self.time_depth = time_depth
+        self.depth = depth
+        self.max_time_embed_period = max_time_embed_period
+
+        time_mix_d_head = d_head
+        n_time_mix_heads = n_heads
+
+        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
+
+        inner_dim = n_heads * d_head
+        if use_spatial_context:
+            time_context_dim = context_dim
+
+        self.time_stack = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    n_time_mix_heads,
+                    time_mix_d_head,
+                    dropout=dropout,
+                    context_dim=time_context_dim,
+                    # timesteps=timesteps,
+                    checkpoint=checkpoint,
+                    ff_in=ff_in,
+                    inner_dim=time_mix_inner_dim,
+                    disable_self_attn=disable_self_attn,
+                    disable_temporal_crossattention=disable_temporal_crossattention,
+                    dtype=dtype, device=device, operations=operations
+                )
+                for _ in range(self.depth)
+            ]
+        )
+
+        assert len(self.time_stack) == len(self.transformer_blocks)
+
+        self.use_spatial_context = use_spatial_context
+        self.in_channels = in_channels
+
+        time_embed_dim = self.in_channels * 4
+        self.time_pos_embed = nn.Sequential(
+            operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
+        )
+
+        self.time_mixer = AlphaBlender(
+            alpha=merge_factor, merge_strategy=merge_strategy
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: Optional[torch.Tensor] = None,
+        time_context: Optional[torch.Tensor] = None,
+        timesteps: Optional[int] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+        transformer_options={}
+    ) -> torch.Tensor:
+        _, _, h, w = x.shape
+        x_in = x
+        spatial_context = None
+        if exists(context):
+            spatial_context = context
+
+        if self.use_spatial_context:
+            assert (
+                context.ndim == 3
+            ), f"n dims of spatial context should be 3 but are {context.ndim}"
+
+            if time_context is None:
+                time_context = context
+            time_context_first_timestep = time_context[::timesteps]
+            time_context = repeat(
+                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
+            )
+        elif time_context is not None and not self.use_spatial_context:
+            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
+            if time_context.ndim == 2:
+                time_context = rearrange(time_context, "b c -> b 1 c")
+
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, "b c h w -> b (h w) c")
+        if self.use_linear:
+            x = self.proj_in(x)
+
+        num_frames = torch.arange(timesteps, device=x.device)
+        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+        num_frames = rearrange(num_frames, "b t -> (b t)")
+        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
+        emb = self.time_pos_embed(t_emb)
+        emb = emb[:, None, :]
+
+        for it_, (block, mix_block) in enumerate(
+            zip(self.transformer_blocks, self.time_stack)
+        ):
+            transformer_options["block_index"] = it_
+            x = block(
+                x,
+                context=spatial_context,
+                transformer_options=transformer_options,
+            )
+
+            x_mix = x
+            x_mix = x_mix + emb
+
+            B, S, C = x_mix.shape
+            x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
+            x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
+            x_mix = rearrange(
+                x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
+            )
+
+            x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
+
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        if not self.use_linear:
+            x = self.proj_out(x)
+        out = x + x_in
+        return out
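Reviewer note (not part of the diff): the heart of SpatialVideoTransformer.forward is a reshape that moves the frame axis onto the attention sequence axis, so the time_stack blocks attend across frames at every spatial location, after which AlphaBlender mixes the temporal result back with the spatial branch. A minimal, self-contained sketch of that reshaping; the sizes b, t, h, w, c are assumed toy values, not taken from the commit.

import torch
from einops import rearrange, repeat

b, t, h, w, c = 2, 14, 8, 8, 320          # assumed toy sizes (14 video frames)
x = torch.randn(b * t, h * w, c)          # spatial tokens, batched as (b*t, tokens, channels)

# frame indices that feed the sinusoidal time_pos_embed, as in the diff
num_frames = repeat(torch.arange(t), "t -> b t", b=b)
num_frames = rearrange(num_frames, "b t -> (b t)")           # shape (b*t,)

# move time onto the sequence axis: every spatial location becomes a batch entry
x_mix = rearrange(x, "(b t) s c -> (b s) t c", t=t)
assert x_mix.shape == (b * h * w, t, c)

# ... the temporal BasicTransformerBlock would run here ...

# and back, so AlphaBlender can blend it with the spatial branch
x_mix = rearrange(x_mix, "(b s) t c -> (b t) s c", b=b, t=t)
assert x_mix.shape == x.shape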
@@ -5,6 +5,8 @@ import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
+from functools import partial

 from .util import (
     checkpoint,
@@ -12,8 +14,9 @@ from .util import (
     zero_module,
     normalization,
     timestep_embedding,
+    AlphaBlender,
 )
-from ..attention import SpatialTransformer
+from ..attention import SpatialTransformer, SpatialVideoTransformer, default
 from comfy.ldm.util import exists
 import comfy.ops

@@ -29,10 +32,15 @@ class TimestepBlock(nn.Module):
     """

 #This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
-def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
+def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
     for layer in ts:
-        if isinstance(layer, TimestepBlock):
+        if isinstance(layer, VideoResBlock):
+            x = layer(x, emb, num_video_frames, image_only_indicator)
+        elif isinstance(layer, TimestepBlock):
             x = layer(x, emb)
+        elif isinstance(layer, SpatialVideoTransformer):
+            x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
+            transformer_options["current_index"] += 1
         elif isinstance(layer, SpatialTransformer):
             x = layer(x, context, transformer_options)
             if "current_index" in transformer_options:
@@ -145,6 +153,9 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
+        kernel_size=3,
+        exchange_temb_dims=False,
+        skip_t_emb=False,
         dtype=None,
         device=None,
         operations=comfy.ops
@@ -157,11 +168,17 @@ class ResBlock(TimestepBlock):
         self.use_conv = use_conv
         self.use_checkpoint = use_checkpoint
         self.use_scale_shift_norm = use_scale_shift_norm
+        self.exchange_temb_dims = exchange_temb_dims
+
+        if isinstance(kernel_size, list):
+            padding = [k // 2 for k in kernel_size]
+        else:
+            padding = kernel_size // 2

         self.in_layers = nn.Sequential(
             nn.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
+            operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
         )

         self.updown = up or down
@@ -175,19 +192,24 @@ class ResBlock(TimestepBlock):
         else:
             self.h_upd = self.x_upd = nn.Identity()

-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            operations.Linear(
-                emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
-            ),
-        )
+        self.skip_t_emb = skip_t_emb
+        if self.skip_t_emb:
+            self.emb_layers = None
+            self.exchange_temb_dims = False
+        else:
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                operations.Linear(
+                    emb_channels,
+                    2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
+                ),
+            )
         self.out_layers = nn.Sequential(
             nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
+                operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
             ),
         )

@@ -195,7 +217,7 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = operations.conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
+                dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
             )
         else:
             self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
@@ -221,19 +243,110 @@ class ResBlock(TimestepBlock):
             h = in_conv(h)
         else:
             h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
+
+        emb_out = None
+        if not self.skip_t_emb:
+            emb_out = self.emb_layers(emb).type(h.dtype)
+            while len(emb_out.shape) < len(h.shape):
+                emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
+            h = out_norm(h)
+            if emb_out is not None:
+                scale, shift = th.chunk(emb_out, 2, dim=1)
+                h *= (1 + scale)
+                h += shift
             h = out_rest(h)
         else:
-            h = h + emb_out
+            if emb_out is not None:
+                if self.exchange_temb_dims:
+                    emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+                h = h + emb_out
             h = self.out_layers(h)
         return self.skip_connection(x) + h

+
+class VideoResBlock(ResBlock):
+    def __init__(
+        self,
+        channels: int,
+        emb_channels: int,
+        dropout: float,
+        video_kernel_size=3,
+        merge_strategy: str = "fixed",
+        merge_factor: float = 0.5,
+        out_channels=None,
+        use_conv: bool = False,
+        use_scale_shift_norm: bool = False,
+        dims: int = 2,
+        use_checkpoint: bool = False,
+        up: bool = False,
+        down: bool = False,
+        dtype=None,
+        device=None,
+        operations=comfy.ops
+    ):
+        super().__init__(
+            channels,
+            emb_channels,
+            dropout,
+            out_channels=out_channels,
+            use_conv=use_conv,
+            use_scale_shift_norm=use_scale_shift_norm,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            up=up,
+            down=down,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+        self.time_stack = ResBlock(
+            default(out_channels, channels),
+            emb_channels,
+            dropout=dropout,
+            dims=3,
+            out_channels=default(out_channels, channels),
+            use_scale_shift_norm=False,
+            use_conv=False,
+            up=False,
+            down=False,
+            kernel_size=video_kernel_size,
+            use_checkpoint=use_checkpoint,
+            exchange_temb_dims=True,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+        self.time_mixer = AlphaBlender(
+            alpha=merge_factor,
+            merge_strategy=merge_strategy,
+            rearrange_pattern="b t -> b 1 t 1 1",
+        )
+
+    def forward(
+        self,
+        x: th.Tensor,
+        emb: th.Tensor,
+        num_video_frames: int,
+        image_only_indicator = None,
+    ) -> th.Tensor:
+        x = super().forward(x, emb)
+
+        x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+
+        x = self.time_stack(
+            x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
+        )
+        x = self.time_mixer(
+            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
+        )
+        x = rearrange(x, "b c t h w -> (b t) c h w")
+        return x
+
+
 class Timestep(nn.Module):
     def __init__(self, dim):
         super().__init__()
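Reviewer note (not part of the diff): VideoResBlock first runs the ordinary 2D ResBlock on a (batch*frames, c, h, w) tensor, then folds the frame axis out into a depth dimension so the dims=3 time_stack ResBlock can mix across time, and finally blends the two paths with AlphaBlender. A small self-contained shape sketch; the sizes are assumed toy values, not taken from the commit.

import torch
from einops import rearrange

b, t, c, h, w = 1, 14, 4, 8, 8                     # assumed toy sizes
x = torch.randn(b * t, c, h, w)                    # output of the spatial ResBlock
emb = torch.randn(b * t, 16)                       # per-frame timestep embedding (toy width)

# spatial result kept aside for blending
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=t)

# the same regrouping feeds the temporal (dims=3) ResBlock
x3d = rearrange(x, "(b t) c h w -> b c t h w", t=t)
assert x3d.shape == (b, c, t, h, w)

# the embedding is regrouped per batch; exchange_temb_dims=True later swaps
# the (t, c) axes so it broadcasts over the depth dimension of the 3D conv
emb3d = rearrange(emb, "(b t) ... -> b t ...", t=t)
emb3d = rearrange(emb3d, "b t c ... -> b c t ...")
assert emb3d.shape == (b, 16, t)

# after blending, frames are flattened back into the batch axis
out = rearrange(x3d, "b c t h w -> (b t) c h w")
assert out.shape == x.shape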
@@ -310,6 +423,16 @@ class UNetModel(nn.Module):
         adm_in_channels=None,
         transformer_depth_middle=None,
         transformer_depth_output=None,
+        use_temporal_resblock=False,
+        use_temporal_attention=False,
+        time_context_dim=None,
+        extra_ff_mix_layer=False,
+        use_spatial_context=False,
+        merge_strategy=None,
+        merge_factor=0.0,
+        video_kernel_size=None,
+        disable_temporal_crossattention=False,
+        max_ddpm_temb_period=10000,
         device=None,
         operations=comfy.ops,
     ):
@@ -364,8 +487,12 @@ class UNetModel(nn.Module):
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
+        self.use_temporal_resblocks = use_temporal_resblock
         self.predict_codebook_ids = n_embed is not None

+        self.default_num_video_frames = None
+        self.default_image_only_indicator = None
+
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
             operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
@@ -402,13 +529,104 @@ class UNetModel(nn.Module):
         input_block_chans = [model_channels]
         ch = model_channels
         ds = 1
+
+        def get_attention_layer(
+            ch,
+            num_heads,
+            dim_head,
+            depth=1,
+            context_dim=None,
+            use_checkpoint=False,
+            disable_self_attn=False,
+        ):
+            if use_temporal_attention:
+                return SpatialVideoTransformer(
+                    ch,
+                    num_heads,
+                    dim_head,
+                    depth=depth,
+                    context_dim=context_dim,
+                    time_context_dim=time_context_dim,
+                    dropout=dropout,
+                    ff_in=extra_ff_mix_layer,
+                    use_spatial_context=use_spatial_context,
+                    merge_strategy=merge_strategy,
+                    merge_factor=merge_factor,
+                    checkpoint=use_checkpoint,
+                    use_linear=use_linear_in_transformer,
+                    disable_self_attn=disable_self_attn,
+                    disable_temporal_crossattention=disable_temporal_crossattention,
+                    max_time_embed_period=max_ddpm_temb_period,
+                    dtype=self.dtype, device=device, operations=operations
+                )
+            else:
+                return SpatialTransformer(
+                            ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
+                            disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                        )
+
+        def get_resblock(
+            merge_factor,
+            merge_strategy,
+            video_kernel_size,
+            ch,
+            time_embed_dim,
+            dropout,
+            out_channels,
+            dims,
+            use_checkpoint,
+            use_scale_shift_norm,
+            down=False,
+            up=False,
+            dtype=None,
+            device=None,
+            operations=comfy.ops
+        ):
+            if self.use_temporal_resblocks:
+                return VideoResBlock(
+                    merge_factor=merge_factor,
+                    merge_strategy=merge_strategy,
+                    video_kernel_size=video_kernel_size,
+                    channels=ch,
+                    emb_channels=time_embed_dim,
+                    dropout=dropout,
+                    out_channels=out_channels,
+                    dims=dims,
+                    use_checkpoint=use_checkpoint,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    down=down,
+                    up=up,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations
+                )
+            else:
+                return ResBlock(
+                    channels=ch,
+                    emb_channels=time_embed_dim,
+                    dropout=dropout,
+                    out_channels=out_channels,
+                    use_checkpoint=use_checkpoint,
+                    dims=dims,
+                    use_scale_shift_norm=use_scale_shift_norm,
+                    down=down,
+                    up=up,
+                    dtype=dtype,
+                    device=device,
+                    operations=operations
+                )
+
         for level, mult in enumerate(channel_mult):
             for nr in range(self.num_res_blocks[level]):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
                         out_channels=mult * model_channels,
                         dims=dims,
                         use_checkpoint=use_checkpoint,
@@ -435,11 +653,9 @@ class UNetModel(nn.Module):
                     disabled_sa = False

                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-                    layers.append(SpatialTransformer(
-                            ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-                        )
+                    layers.append(get_attention_layer(
+                            ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+                            disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
@@ -448,10 +664,13 @@ class UNetModel(nn.Module):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
                             out_channels=out_ch,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
@@ -481,10 +700,14 @@ class UNetModel(nn.Module):
         #num_heads = 1
         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         mid_block = [
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                time_embed_dim=time_embed_dim,
+                dropout=dropout,
+                out_channels=None,
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
@@ -493,15 +716,18 @@ class UNetModel(nn.Module):
                 operations=operations
             )]
         if transformer_depth_middle >= 0:
-            mid_block += [SpatialTransformer(  # always uses a self-attn
+            mid_block += [get_attention_layer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
                         ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
+            get_resblock(
+                merge_factor=merge_factor,
+                merge_strategy=merge_strategy,
+                video_kernel_size=video_kernel_size,
+                ch=ch,
+                time_embed_dim=time_embed_dim,
+                dropout=dropout,
+                out_channels=None,
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
@@ -517,10 +743,13 @@ class UNetModel(nn.Module):
             for i in range(self.num_res_blocks[level] + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    get_resblock(
+                        merge_factor=merge_factor,
+                        merge_strategy=merge_strategy,
+                        video_kernel_size=video_kernel_size,
+                        ch=ch + ich,
+                        time_embed_dim=time_embed_dim,
+                        dropout=dropout,
                         out_channels=model_channels * mult,
                         dims=dims,
                         use_checkpoint=use_checkpoint,
@@ -548,19 +777,21 @@ class UNetModel(nn.Module):

                 if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                     layers.append(
-                        SpatialTransformer(
+                        get_attention_layer(
                             ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                            disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
                         )
                     )
                 if level and i == self.num_res_blocks[level]:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
+                        get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
                             out_channels=out_ch,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
@@ -602,6 +833,10 @@ class UNetModel(nn.Module):
         transformer_options["current_index"] = 0
         transformer_patches = transformer_options.get("patches", {})

+        num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
+        image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
+        time_context = kwargs.get("time_context", None)
+
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
@@ -616,7 +851,7 @@ class UNetModel(nn.Module):
         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
             transformer_options["block"] = ("input", id)
-            h = forward_timestep_embed(module, h, emb, context, transformer_options)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
             h = apply_control(h, control, 'input')
             if "input_block_patch" in transformer_patches:
                 patch = transformer_patches["input_block_patch"]
@@ -630,9 +865,10 @@ class UNetModel(nn.Module):
                     h = p(h, transformer_options)

         transformer_options["block"] = ("middle", 0)
-        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
+        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
         h = apply_control(h, control, 'middle')

+
         for id, module in enumerate(self.output_blocks):
             transformer_options["block"] = ("output", id)
             hsp = hs.pop()
@@ -649,7 +885,7 @@ class UNetModel(nn.Module):
                 output_shape = hs[-1].shape
             else:
                 output_shape = None
-            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
+            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
             return self.id_predictor(h)
@@ -13,11 +13,78 @@ import math
 import torch
 import torch.nn as nn
 import numpy as np
-from einops import repeat
+from einops import repeat, rearrange

 from comfy.ldm.util import instantiate_from_config
 import comfy.ops

+class AlphaBlender(nn.Module):
+    strategies = ["learned", "fixed", "learned_with_images"]
+
+    def __init__(
+        self,
+        alpha: float,
+        merge_strategy: str = "learned_with_images",
+        rearrange_pattern: str = "b t -> (b t) 1 1",
+    ):
+        super().__init__()
+        self.merge_strategy = merge_strategy
+        self.rearrange_pattern = rearrange_pattern
+
+        assert (
+            merge_strategy in self.strategies
+        ), f"merge_strategy needs to be in {self.strategies}"
+
+        if self.merge_strategy == "fixed":
+            self.register_buffer("mix_factor", torch.Tensor([alpha]))
+        elif (
+            self.merge_strategy == "learned"
+            or self.merge_strategy == "learned_with_images"
+        ):
+            self.register_parameter(
+                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+            )
+        else:
+            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
+        # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
+        if self.merge_strategy == "fixed":
+            # make shape compatible
+            # alpha = repeat(self.mix_factor, '1 -> b () t  () ()', t=t, b=bs)
+            alpha = self.mix_factor
+        elif self.merge_strategy == "learned":
+            alpha = torch.sigmoid(self.mix_factor)
+            # make shape compatible
+            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
+        elif self.merge_strategy == "learned_with_images":
+            assert image_only_indicator is not None, "need image_only_indicator ..."
+            alpha = torch.where(
+                image_only_indicator.bool(),
+                torch.ones(1, 1, device=image_only_indicator.device),
+                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+            )
+            alpha = rearrange(alpha, self.rearrange_pattern)
+            # make shape compatible
+            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
+        else:
+            raise NotImplementedError()
+        return alpha
+
+    def forward(
+        self,
+        x_spatial,
+        x_temporal,
+        image_only_indicator=None,
+    ) -> torch.Tensor:
+        alpha = self.get_alpha(image_only_indicator)
+        x = (
+            alpha.to(x_spatial.dtype) * x_spatial
+            + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
+        )
+        return x
+
+
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
         betas = (
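Reviewer note (not part of the diff): AlphaBlender boils down to a convex combination alpha * x_spatial + (1 - alpha) * x_temporal, where with merge_strategy "learned_with_images" alpha comes from sigmoid(mix_factor) and is forced to 1 (spatial only) for frames flagged by image_only_indicator. A standalone sketch of that rule with assumed toy shapes; it re-implements the math rather than importing the class above.

import torch

mix_factor = torch.nn.Parameter(torch.Tensor([0.0]))       # alpha=0.0, as in the SVD config; sigmoid(0)=0.5 at init
image_only_indicator = torch.zeros(1, 14)                   # (batch, frames); 0 means a real video frame

# learned_with_images: sigmoid(mix_factor) per frame, overridden to 1 for still images
alpha = torch.where(
    image_only_indicator.bool(),
    torch.ones(1, 1),
    torch.sigmoid(mix_factor)[..., None],
)                                                            # -> (1, 14)
alpha = alpha.reshape(-1, 1, 1)                              # "b t -> (b t) 1 1" so it broadcasts over tokens

x_spatial = torch.randn(14, 64, 320)                         # assumed toy token tensors
x_temporal = torch.randn(14, 64, 320)
out = alpha * x_spatial + (1.0 - alpha) * x_temporal         # the whole blend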
comfy/ldm/modules/temporal_ae.py (new file, 244 lines)
@@ -0,0 +1,244 @@
+import functools
+from typing import Callable, Iterable, Union
+
+import torch
+from einops import rearrange, repeat
+
+import comfy.ops
+
+from .diffusionmodules.model import (
+    AttnBlock,
+    Decoder,
+    ResnetBlock,
+)
+from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
+from .attention import BasicTransformerBlock
+
+def partialclass(cls, *args, **kwargs):
+    class NewCls(cls):
+        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+    return NewCls
+
+
+class VideoResBlock(ResnetBlock):
+    def __init__(
+        self,
+        out_channels,
+        *args,
+        dropout=0.0,
+        video_kernel_size=3,
+        alpha=0.0,
+        merge_strategy="learned",
+        **kwargs,
+    ):
+        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
+        if video_kernel_size is None:
+            video_kernel_size = [3, 1, 1]
+        self.time_stack = ResBlock(
+            channels=out_channels,
+            emb_channels=0,
+            dropout=dropout,
+            dims=3,
+            use_scale_shift_norm=False,
+            use_conv=False,
+            up=False,
+            down=False,
+            kernel_size=video_kernel_size,
+            use_checkpoint=False,
+            skip_t_emb=True,
+        )
+
+        self.merge_strategy = merge_strategy
+        if self.merge_strategy == "fixed":
+            self.register_buffer("mix_factor", torch.Tensor([alpha]))
+        elif self.merge_strategy == "learned":
+            self.register_parameter(
+                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+            )
+        else:
+            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+    def get_alpha(self, bs):
+        if self.merge_strategy == "fixed":
+            return self.mix_factor
+        elif self.merge_strategy == "learned":
+            return torch.sigmoid(self.mix_factor)
+        else:
+            raise NotImplementedError()
+
+    def forward(self, x, temb, skip_video=False, timesteps=None):
+        b, c, h, w = x.shape
+        if timesteps is None:
+            timesteps = b
+
+        x = super().forward(x, temb)
+
+        if not skip_video:
+            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+            x = self.time_stack(x, temb)
+
+            alpha = self.get_alpha(bs=b // timesteps)
+            x = alpha * x + (1.0 - alpha) * x_mix
+
+            x = rearrange(x, "b c t h w -> (b t) c h w")
+        return x
+
+
+class AE3DConv(torch.nn.Conv2d):
+    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
+        super().__init__(in_channels, out_channels, *args, **kwargs)
+        if isinstance(video_kernel_size, Iterable):
+            padding = [int(k // 2) for k in video_kernel_size]
+        else:
+            padding = int(video_kernel_size // 2)
+
+        self.time_mix_conv = torch.nn.Conv3d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=video_kernel_size,
+            padding=padding,
+        )
+
+    def forward(self, input, timesteps=None, skip_video=False):
+        if timesteps is None:
+            timesteps = input.shape[0]
+        x = super().forward(input)
+        if skip_video:
+            return x
+        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+        x = self.time_mix_conv(x)
+        return rearrange(x, "b c t h w -> (b t) c h w")
+
+
+class AttnVideoBlock(AttnBlock):
+    def __init__(
+        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
+    ):
+        super().__init__(in_channels)
+        # no context, single headed, as in base class
+        self.time_mix_block = BasicTransformerBlock(
+            dim=in_channels,
+            n_heads=1,
+            d_head=in_channels,
+            checkpoint=False,
+            ff_in=True,
+        )
+
+        time_embed_dim = self.in_channels * 4
+        self.video_time_embed = torch.nn.Sequential(
+            comfy.ops.Linear(self.in_channels, time_embed_dim),
+            torch.nn.SiLU(),
+            comfy.ops.Linear(time_embed_dim, self.in_channels),
+        )
+
+        self.merge_strategy = merge_strategy
+        if self.merge_strategy == "fixed":
+            self.register_buffer("mix_factor", torch.Tensor([alpha]))
+        elif self.merge_strategy == "learned":
+            self.register_parameter(
+                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+            )
+        else:
+            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+    def forward(self, x, timesteps=None, skip_time_block=False):
+        if skip_time_block:
+            return super().forward(x)
+
+        if timesteps is None:
+            timesteps = x.shape[0]
+
+        x_in = x
+        x = self.attention(x)
+        h, w = x.shape[2:]
+        x = rearrange(x, "b c h w -> b (h w) c")
+
+        x_mix = x
+        num_frames = torch.arange(timesteps, device=x.device)
+        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+        num_frames = rearrange(num_frames, "b t -> (b t)")
+        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
+        emb = self.video_time_embed(t_emb)  # b, n_channels
+        emb = emb[:, None, :]
+        x_mix = x_mix + emb
+
+        alpha = self.get_alpha()
+        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge
+
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        x = self.proj_out(x)
+
+        return x_in + x
+
+    def get_alpha(
+        self,
+    ):
+        if self.merge_strategy == "fixed":
+            return self.mix_factor
+        elif self.merge_strategy == "learned":
+            return torch.sigmoid(self.mix_factor)
+        else:
+            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+
+def make_time_attn(
+    in_channels,
+    attn_type="vanilla",
+    attn_kwargs=None,
+    alpha: float = 0,
+    merge_strategy: str = "learned",
+):
+    return partialclass(
+        AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
+    )
+
+
+class Conv2DWrapper(torch.nn.Conv2d):
+    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
+        return super().forward(input)
+
+
+class VideoDecoder(Decoder):
+    available_time_modes = ["all", "conv-only", "attn-only"]
+
+    def __init__(
+        self,
+        *args,
+        video_kernel_size: Union[int, list] = 3,
+        alpha: float = 0.0,
+        merge_strategy: str = "learned",
+        time_mode: str = "conv-only",
+        **kwargs,
+    ):
+        self.video_kernel_size = video_kernel_size
+        self.alpha = alpha
+        self.merge_strategy = merge_strategy
+        self.time_mode = time_mode
+        assert (
+            self.time_mode in self.available_time_modes
+        ), f"time_mode parameter has to be in {self.available_time_modes}"
+
+        if self.time_mode != "attn-only":
+            kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
+        if self.time_mode not in ["conv-only", "only-last-conv"]:
+            kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
+        if self.time_mode not in ["attn-only", "only-last-conv"]:
+            kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
+
+        super().__init__(*args, **kwargs)
+
+    def get_last_layer(self, skip_time_mix=False, **kwargs):
+        if self.time_mode == "attn-only":
+            raise NotImplementedError("TODO")
+        else:
+            return (
+                self.conv_out.time_mix_conv.weight
+                if not skip_time_mix
+                else self.conv_out.weight
+            )
@@ -10,17 +10,22 @@ from . import utils
 class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2
+    V_PREDICTION_EDM = 3


-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM


 def model_sampling(model_config, model_type):
+    s = ModelSamplingDiscrete
+
     if model_type == ModelType.EPS:
         c = EPS
     elif model_type == ModelType.V_PREDICTION:
         c = V_PREDICTION
+    elif model_type == ModelType.V_PREDICTION_EDM:
+        c = V_PREDICTION
+        s = ModelSamplingContinuousEDM

-    s = ModelSamplingDiscrete

     class ModelSampling(s, c):
         pass
@@ -262,3 +267,48 @@ class SDXL(BaseModel):
         out.append(self.embedder(torch.Tensor([target_width])))
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+
+class SVD_img2vid(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.embedder = Timestep(256)
+
+    def encode_adm(self, **kwargs):
+        fps_id = kwargs.get("fps", 6) - 1
+        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
+        augmentation = kwargs.get("augmentation_level", 0)
+
+        out = []
+        out.append(self.embedder(torch.Tensor([fps_id])))
+        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
+        out.append(self.embedder(torch.Tensor([augmentation])))
+
+        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
+        return flat
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+
+        latent_image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if latent_image is None:
+            latent_image = torch.zeros_like(noise)
+
+        if latent_image.shape[1:] != noise.shape[1:]:
+            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
+
+        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
+
+        if "time_conditioning" in kwargs:
+            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
+
+        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
+        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
+        return out
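Reviewer note (not part of the diff): SVD's extra `y` conditioning is just three scalars (fps minus 1, motion bucket id, augmentation level), each run through the 256-dim Timestep embedder and concatenated into one 768-dim vector. The sketch below uses a plain sinusoidal embedding to show the shape of the result; the real embedder is comfy's Timestep module, so treat this as an approximation for illustration only.

import math
import torch

def sinusoidal_embedding(value: float, dim: int = 256) -> torch.Tensor:
    # transformer-style sinusoidal embedding of a single scalar (assumed stand-in for Timestep(256))
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half) / half)
    args = value * freqs
    return torch.cat([torch.cos(args), torch.sin(args)])

fps_id = 6 - 1              # default fps=6 in the diff
motion_bucket_id = 127
augmentation_level = 0

y = torch.cat([
    sinusoidal_embedding(fps_id),
    sinusoidal_embedding(motion_bucket_id),
    sinusoidal_embedding(augmentation_level),
]).unsqueeze(0)             # -> shape (1, 768), broadcast over the frame batch later
assert y.shape == (1, 768)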
@@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
         last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
         context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
         use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
-        return last_transformer_depth, context_dim, use_linear_in_transformer
+        time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
+        return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
     return None

 def detect_unet_config(state_dict, key_prefix, dtype):
@@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     context_dim = None
     use_linear_in_transformer = False

+    video_model = False

     current_res = 1
     count = 0
@@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
             if context_dim is None:
                 context_dim = out[1]
                 use_linear_in_transformer = out[2]
+                video_model = out[3]
         else:
             transformer_depth.append(0)

@@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     unet_config["transformer_depth_middle"] = transformer_depth_middle
     unet_config['use_linear_in_transformer'] = use_linear_in_transformer
     unet_config["context_dim"] = context_dim
+
+    if video_model:
+        unet_config["extra_ff_mix_layer"] = True
+        unet_config["use_spatial_context"] = True
+        unet_config["merge_strategy"] = "learned_with_images"
+        unet_config["merge_factor"] = 0.0
+        unet_config["video_kernel_size"] = [3, 1, 1]
+        unet_config["use_temporal_resblock"] = True
+        unet_config["use_temporal_attention"] = True
+    else:
+        unet_config["use_temporal_resblock"] = False
+        unet_config["use_temporal_attention"] = False
+
     return unet_config
 
 def model_config_from_unet_config(unet_config):
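Note: the video branch is selected purely by key sniffing on the checkpoint. A toy illustration (fake state dict and prefix invented for the example) of the time_stack check added above:

import torch

# A single temporal-attention weight key is enough to flip the video_model flag.
fake_sd = {"input_blocks.1.1.time_stack.0.attn1.to_q.weight": torch.zeros(320, 320)}
prefix = "input_blocks.1."
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in fake_sd or \
             '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in fake_sd
print(time_stack)  # True -> detect_unet_config() then sets the temporal UNet options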
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
+import math
 
 class EPS:
     def calculate_input(self, sigma, noise):
@@ -83,3 +83,47 @@ class ModelSamplingDiscrete(torch.nn.Module):
         percent = 1.0 - percent
         return self.sigma(torch.tensor(percent * 999.0)).item()
 
+
+class ModelSamplingContinuousEDM(torch.nn.Module):
+    def __init__(self, model_config=None):
+        super().__init__()
+        self.sigma_data = 1.0
+
+        if model_config is not None:
+            sampling_settings = model_config.sampling_settings
+        else:
+            sampling_settings = {}
+
+        sigma_min = sampling_settings.get("sigma_min", 0.002)
+        sigma_max = sampling_settings.get("sigma_max", 120.0)
+        self.set_sigma_range(sigma_min, sigma_max)
+
+    def set_sigma_range(self, sigma_min, sigma_max):
+        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
+
+        self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        return 0.25 * sigma.log()
+
+    def sigma(self, timestep):
+        return (timestep / 0.25).exp()
+
+    def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return 999999999.9
+        if percent >= 1.0:
+            return 0.0
+        percent = 1.0 - percent
+
+        log_sigma_min = math.log(self.sigma_min)
+        return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
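Note: a quick standalone sanity check of the continuous EDM mapping (ignoring the boundary clamps): timestep() and sigma() are exact inverses, and percent_to_sigma() interpolates log-linearly between sigma_min and sigma_max. The 0.002/700.0 values are the SVD sampling_settings that appear later in this diff.

import math
import torch

sigma_min, sigma_max = 0.002, 700.0

# timestep(sigma) = 0.25 * ln(sigma)  <=>  sigma(t) = exp(t / 0.25)
s = torch.tensor(2.5)
t = 0.25 * s.log()
assert torch.isclose((t / 0.25).exp(), s)

def percent_to_sigma(percent):
    # 0% of sampling sits at the sigma_max end, 100% at the sigma_min end.
    percent = 1.0 - percent
    log_min = math.log(sigma_min)
    return math.exp((math.log(sigma_max) - log_min) * percent + log_min)

print(percent_to_sigma(0.5))  # geometric mean of 0.002 and 700, roughly 1.18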

comfy/sd.py (10 changed lines)

@@ -159,7 +159,15 @@ class VAE:
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
 
         if config is None:
-            if "taesd_decoder.1.weight" in sd:
+            if "decoder.mid.block_1.mix_factor" in sd:
+                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+                decoder_config = encoder_config.copy()
+                decoder_config["video_kernel_size"] = [3, 1, 1]
+                decoder_config["alpha"] = 0.0
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
+                                                            decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
+            elif "taesd_decoder.1.weight" in sd:
                 self.first_stage_model = comfy.taesd.taesd.TAESD()
             else:
                 #default SD1.x/SD2.x VAE parameters
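Note: a small illustration (helper name invented) of the key-based dispatch order the VAE loader follows after this change; the mix_factor key is specific to the SVD VAE's temporally-aware decoder:

def guess_vae_kind(sd_keys):
    # Mirrors the order of the checks in the hunk above.
    if "decoder.mid.block_1.mix_factor" in sd_keys:
        return "SVD video VAE (temporal VideoDecoder)"
    if "taesd_decoder.1.weight" in sd_keys:
        return "TAESD"
    return "standard SD1.x/SD2.x VAE"

print(guess_vae_kind({"decoder.mid.block_1.mix_factor"}))  # SVD video VAE (temporal VideoDecoder)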
@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
         "model_channels": 320,
         "use_linear_in_transformer": False,
         "adm_in_channels": None,
+        "use_temporal_attention": False,
     }
 
     unet_extra_config = {
@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
         "model_channels": 320,
         "use_linear_in_transformer": True,
         "adm_in_channels": None,
+        "use_temporal_attention": False,
     }
 
     latent_format = latent_formats.SD15
@@ -88,6 +90,7 @@ class SD21UnclipL(SD20):
         "model_channels": 320,
         "use_linear_in_transformer": True,
         "adm_in_channels": 1536,
+        "use_temporal_attention": False,
     }
 
     clip_vision_prefix = "embedder.model.visual."
@@ -100,6 +103,7 @@ class SD21UnclipH(SD20):
         "model_channels": 320,
         "use_linear_in_transformer": True,
         "adm_in_channels": 2048,
+        "use_temporal_attention": False,
     }
 
     clip_vision_prefix = "embedder.model.visual."
@@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE):
         "context_dim": 1280,
         "adm_in_channels": 2560,
         "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
+        "use_temporal_attention": False,
     }
 
     latent_format = latent_formats.SDXL
@@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE):
         "use_linear_in_transformer": True,
         "transformer_depth": [0, 0, 2, 2, 10, 10],
         "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
     }
 
     latent_format = latent_formats.SDXL
@@ -203,8 +209,34 @@ class SSD1B(SDXL):
         "use_linear_in_transformer": True,
         "transformer_depth": [0, 0, 2, 2, 4, 4],
         "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
     }
 
+
+class SVD_img2vid(supported_models_base.BASE):
+    unet_config = {
+        "model_channels": 320,
+        "in_channels": 8,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
+        "context_dim": 1024,
+        "adm_in_channels": 768,
+        "use_temporal_attention": True,
+        "use_temporal_resblock": True
+    }
+
+    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
+
+    latent_format = latent_formats.SD15
+
+    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SVD_img2vid(self, device=device)
+        return out
+
+    def clip_target(self):
+        return None
+
 models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
+models += [SVD_img2vid]
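Note: in_channels is 8 rather than the usual 4 because the c_concat latent built in extra_conds() is concatenated channel-wise with the noise latent before it enters the UNet. A rough shape check (frame count and resolution are assumed, not taken from the diff):

import torch

num_frames, h, w = 14, 72, 128                      # assumed latent dims (e.g. 1024x576 / 8)
noise = torch.randn(num_frames, 4, h, w)            # latent being denoised, one entry per frame
image_latent = torch.randn(1, 4, h, w)              # encoded conditioning frame
image_latent = image_latent.repeat(num_frames, 1, 1, 1)   # repeat_to_batch_size equivalent
unet_in = torch.cat([noise, image_latent], dim=1)
print(unet_in.shape)  # torch.Size([14, 8, 72, 128]) -> matches "in_channels": 8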
@@ -128,6 +128,36 @@ class ModelSamplingDiscrete:
         m.add_object_patch("model_sampling", model_sampling)
         return (m, )
 
+class ModelSamplingContinuousEDM:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "sampling": (["v_prediction", "eps"],),
+                              "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, sampling, sigma_max, sigma_min):
+        m = model.clone()
+
+        if sampling == "eps":
+            sampling_type = comfy.model_sampling.EPS
+        elif sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+
+        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced()
+        model_sampling.set_sigma_range(sigma_min, sigma_max)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
 class RescaleCFG:
     @classmethod
     def INPUT_TYPES(s):
@@ -169,5 +199,6 @@ class RescaleCFG:
 
 NODE_CLASS_MAPPINGS = {
     "ModelSamplingDiscrete": ModelSamplingDiscrete,
+    "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
    "RescaleCFG": RescaleCFG,
 }
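Note: the patch() method above combines the EDM schedule with a prediction type purely through multiple inheritance. A toy illustration of that mixin pattern with stand-in classes (these are not the comfy.model_sampling classes):

import math

class ContinuousEDMToy:
    def timestep(self, sigma):
        return 0.25 * math.log(sigma)

class EPSToy:
    def calculate_denoised(self, sigma, model_output, model_input):
        # eps-prediction: the network output is the noise to subtract, scaled by sigma.
        return model_input - model_output * sigma

class SamplingAdvancedToy(ContinuousEDMToy, EPSToy):
    pass  # schedule from one parent, denoising rule from the other

s = SamplingAdvancedToy()
print(s.timestep(2.0), s.calculate_denoised(2.0, 0.5, 1.0))  # 0.173..., 0.0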