diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 9a63202a..76a525b3 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -54,6 +54,7 @@ class ControlNet(nn.Module): transformer_depth_output=None, device=None, operations=comfy.ops, + **kwargs, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 016795a5..947e2008 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,10 @@ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any +from functools import partial -from .diffusionmodules.util import checkpoint + +from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management @@ -370,21 +372,45 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False, dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): super().__init__() + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + self.disable_self_attn = disable_self_attn - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + context_dim_attn2 = None + if not switch_temporal_ca_to_sa: + context_dim_attn2 = context_dim + + self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, + heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + + self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) 
self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa def forward(self, x, context=None, transformer_options={}): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) @@ -418,6 +444,12 @@ class BasicTransformerBlock(nn.Module): else: transformer_patches_replace = {} + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + n = self.norm1(x) if self.disable_self_attn: context_attn1 = context @@ -465,31 +497,34 @@ class BasicTransformerBlock(nn.Module): for p in patch: x = p(x, extra_options) - n = self.norm2(x) - - context_attn2 = context - value_attn2 = None - if "attn2_patch" in transformer_patches: - patch = transformer_patches["attn2_patch"] - value_attn2 = context_attn2 - for p in patch: - n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) - - attn2_replace_patch = transformer_patches_replace.get("attn2", {}) - block_attn2 = transformer_block - if block_attn2 not in attn2_replace_patch: - block_attn2 = block - - if block_attn2 in attn2_replace_patch: - if value_attn2 is None: + if self.attn2 is not None: + n = self.norm2(x) + if self.switch_temporal_ca_to_sa: + context_attn2 = n + else: + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] value_attn2 = context_attn2 - n = self.attn2.to_q(n) - context_attn2 = self.attn2.to_k(context_attn2) - value_attn2 = self.attn2.to_v(value_attn2) - n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) - n = self.attn2.to_out(n) - else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + for p in patch: + n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) + + attn2_replace_patch = transformer_patches_replace.get("attn2", {}) + block_attn2 = transformer_block + if block_attn2 not in attn2_replace_patch: + block_attn2 = block + + if block_attn2 in attn2_replace_patch: + if value_attn2 is None: + value_attn2 = context_attn2 + n = self.attn2.to_q(n) + context_attn2 = self.attn2.to_k(context_attn2) + value_attn2 = self.attn2.to_v(value_attn2) + n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) + n = self.attn2.to_out(n) + else: + n = self.attn2(n, context=context_attn2, value=value_attn2) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -497,7 +532,12 @@ class BasicTransformerBlock(nn.Module): n = p(n, extra_options) x += n - x = self.ff(self.norm3(x)) + x + if self.is_res: + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + return x @@ -565,3 +605,164 @@ class SpatialTransformer(nn.Module): x = self.proj_out(x) return x + x_in + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype=None, device=None, operations=comfy.ops + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + use_checkpoint=checkpoint, + context_dim=context_dim, + 
use_linear=use_linear, + disable_self_attn=disable_self_attn, + dtype=dtype, device=device, operations=operations + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + # timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + dtype=dtype, device=device, operations=operations + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device), + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + transformer_options={} + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" + + if time_context is None: + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat( + time_context_first_timestep, "b ... -> (b n) ...", n=h * w + ) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... 
-> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate( + zip(self.transformer_blocks, self.time_stack) + ): + transformer_options["block_index"] = it_ + x = block( + x, + context=spatial_context, + transformer_options=transformer_options, + ) + + x_mix = x + x_mix = x_mix + emb + + B, S, C = x_mix.shape + x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) + x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = rearrange( + x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps + ) + + x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out + + diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index e8f35a54..a497ed34 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -5,6 +5,8 @@ import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F +from einops import rearrange +from functools import partial from .util import ( checkpoint, @@ -12,8 +14,9 @@ from .util import ( zero_module, normalization, timestep_embedding, + AlphaBlender, ) -from ..attention import SpatialTransformer +from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists import comfy.ops @@ -29,10 +32,15 @@ class TimestepBlock(nn.Module): """ #This is needed because accelerate makes a copy of transformer_options which breaks "current_index" -def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): +def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: - if isinstance(layer, TimestepBlock): + if isinstance(layer, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(layer, TimestepBlock): x = layer(x, emb) + elif isinstance(layer, SpatialVideoTransformer): + x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) + transformer_options["current_index"] += 1 elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) if "current_index" in transformer_options: @@ -145,6 +153,9 @@ class ResBlock(TimestepBlock): use_checkpoint=False, up=False, down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, dtype=None, device=None, operations=comfy.ops @@ -157,11 +168,17 @@ class ResBlock(TimestepBlock): self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = 
exchange_temb_dims + + if isinstance(kernel_size, list): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 self.in_layers = nn.Sequential( nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), + operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), ) self.updown = up or down @@ -175,19 +192,24 @@ class ResBlock(TimestepBlock): else: self.h_upd = self.x_upd = nn.Identity() - self.emb_layers = nn.Sequential( - nn.SiLU(), - operations.Linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device - ), - ) + self.skip_t_emb = skip_t_emb + if self.skip_t_emb: + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + operations.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device + ), + ) self.out_layers = nn.Sequential( nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) ), ) @@ -195,7 +217,7 @@ class ResBlock(TimestepBlock): self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = operations.conv_nd( - dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device + dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device ) else: self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) @@ -221,19 +243,110 @@ class ResBlock(TimestepBlock): h = in_conv(h) else: h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] + + emb_out = None + if not self.skip_t_emb: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift + h = out_norm(h) + if emb_out is not None: + scale, shift = th.chunk(emb_out, 2, dim=1) + h *= (1 + scale) + h += shift h = out_rest(h) else: - h = h + emb_out + if emb_out is not None: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... 
-> b c t ...") + h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h + +class VideoResBlock(ResBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + video_kernel_size=3, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + out_channels=None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + dtype=None, + device=None, + operations=comfy.ops + ): + super().__init__( + channels, + emb_channels, + dropout, + out_channels=out_channels, + use_conv=use_conv, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + up=up, + down=down, + dtype=dtype, + device=device, + operations=operations + ) + + self.time_stack = ResBlock( + default(out_channels, channels), + emb_channels, + dropout=dropout, + dims=3, + out_channels=default(out_channels, channels), + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=use_checkpoint, + exchange_temb_dims=True, + dtype=dtype, + device=device, + operations=operations + ) + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + num_video_frames: int, + image_only_indicator = None, + ) -> th.Tensor: + x = super().forward(x, emb) + + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + + x = self.time_stack( + x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames) + ) + x = self.time_mixer( + x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator + ) + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + class Timestep(nn.Module): def __init__(self, dim): super().__init__() @@ -310,6 +423,16 @@ class UNetModel(nn.Module): adm_in_channels=None, transformer_depth_middle=None, transformer_depth_output=None, + use_temporal_resblock=False, + use_temporal_attention=False, + time_context_dim=None, + extra_ff_mix_layer=False, + use_spatial_context=False, + merge_strategy=None, + merge_factor=0.0, + video_kernel_size=None, + disable_temporal_crossattention=False, + max_ddpm_temb_period=10000, device=None, operations=comfy.ops, ): @@ -364,8 +487,12 @@ class UNetModel(nn.Module): self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample + self.use_temporal_resblocks = use_temporal_resblock self.predict_codebook_ids = n_embed is not None + self.default_num_video_frames = None + self.default_image_only_indicator = None + time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), @@ -402,13 +529,104 @@ class UNetModel(nn.Module): input_block_chans = [model_channels] ch = model_channels ds = 1 + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, + disable_self_attn=False, + ): + if use_temporal_attention: + return SpatialVideoTransformer( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=merge_strategy, + merge_factor=merge_factor, + checkpoint=use_checkpoint, + 
use_linear=use_linear_in_transformer, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + max_time_embed_period=max_ddpm_temb_period, + dtype=self.dtype, device=device, operations=operations + ) + else: + return SpatialTransformer( + ch, num_heads, dim_head, depth=depth, context_dim=context_dim, + disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + ) + + def get_resblock( + merge_factor, + merge_strategy, + video_kernel_size, + ch, + time_embed_dim, + dropout, + out_channels, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + dtype=None, + device=None, + operations=comfy.ops + ): + if self.use_temporal_resblocks: + return VideoResBlock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + else: + return ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + use_checkpoint=use_checkpoint, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, @@ -435,11 +653,9 @@ class UNetModel(nn.Module): disabled_sa = False if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: - layers.append(SpatialTransformer( + layers.append(get_attention_layer( ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations - ) + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch @@ -448,10 +664,13 @@ class UNetModel(nn.Module): out_ch = ch self.input_blocks.append( TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -481,10 +700,14 @@ class UNetModel(nn.Module): #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels mid_block = [ - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, @@ -493,15 +716,18 @@ class UNetModel(nn.Module): operations=operations )] if transformer_depth_middle >= 0: - mid_block += [SpatialTransformer( # always 
uses a self-attn + mid_block += [get_attention_layer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint ), - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, @@ -517,10 +743,13 @@ class UNetModel(nn.Module): for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch + ich, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, @@ -548,19 +777,21 @@ class UNetModel(nn.Module): if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( - SpatialTransformer( + get_attention_layer( ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint ) ) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -602,6 +833,10 @@ class UNetModel(nn.Module): transformer_options["current_index"] = 0 transformer_patches = transformer_options.get("patches", {}) + num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) + image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + time_context = kwargs.get("time_context", None) + assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -616,7 +851,7 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) - h = forward_timestep_embed(module, h, emb, context, transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'input') if "input_block_patch" in transformer_patches: patch = transformer_patches["input_block_patch"] @@ -630,9 +865,10 @@ class UNetModel(nn.Module): h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + for id, 
module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() @@ -649,7 +885,7 @@ class UNetModel(nn.Module): output_shape = hs[-1].shape else: output_shape = None - h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 0298ca99..704bbe57 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -13,11 +13,78 @@ import math import torch import torch.nn as nn import numpy as np -from einops import repeat +from einops import repeat, rearrange from comfy.ldm.util import instantiate_from_config import comfy.ops +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert ( + merge_strategy in self.strategies + ), f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif ( + self.merge_strategy == "learned" + or self.merge_strategy == "learned_with_images" + ): + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) + if self.merge_strategy == "fixed": + # make shape compatible + # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 
1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + else: + raise NotImplementedError() + return alpha + + def forward( + self, + x_spatial, + x_temporal, + image_only_indicator=None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = ( + alpha.to(x_spatial.dtype) * x_spatial + + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + ) + return x + + def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": betas = ( diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py new file mode 100644 index 00000000..11ae049f --- /dev/null +++ b/comfy/ldm/modules/temporal_ae.py @@ -0,0 +1,244 @@ +import functools +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +import comfy.ops + +from .diffusionmodules.model import ( + AttnBlock, + Decoder, + ResnetBlock, +) +from .diffusionmodules.openaimodel import ResBlock, timestep_embedding +from .attention import BasicTransformerBlock + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + b, c, h, w = x.shape + if timesteps is None: + timesteps = b + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps=None, skip_video=False): + if timesteps is None: + timesteps = input.shape[0] + x = super().forward(input) + if 
skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class AttnVideoBlock(AttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = BasicTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + comfy.ops.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + comfy.ops.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps=None, skip_time_block=False): + if skip_time_block: + return super().forward(x) + + if timesteps is None: + timesteps = x.shape[0] + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + return partialclass( + AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy + ) + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + + if self.time_mode != "attn-only": + kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + if self.time_mode not in ["conv-only", "only-last-conv"]: + kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy) + if self.time_mode not in ["attn-only", "only-last-conv"]: + kwargs["resnet_op"] = partialclass(VideoResBlock, 
video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy) + + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return ( + self.conv_out.time_mix_conv.weight + if not skip_time_mix + else self.conv_out.weight + ) diff --git a/comfy/model_base.py b/comfy/model_base.py index 772e2693..34274c4a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -10,17 +10,22 @@ from . import utils class ModelType(Enum): EPS = 1 V_PREDICTION = 2 + V_PREDICTION_EDM = 3 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM + def model_sampling(model_config, model_type): + s = ModelSamplingDiscrete + if model_type == ModelType.EPS: c = EPS elif model_type == ModelType.V_PREDICTION: c = V_PREDICTION - - s = ModelSamplingDiscrete + elif model_type == ModelType.V_PREDICTION_EDM: + c = V_PREDICTION + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass @@ -262,3 +267,48 @@ class SDXL(BaseModel): out.append(self.embedder(torch.Tensor([target_width]))) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + +class SVD_img2vid(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder = Timestep(256) + + def encode_adm(self, **kwargs): + fps_id = kwargs.get("fps", 6) - 1 + motion_bucket_id = kwargs.get("motion_bucket_id", 127) + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.Tensor([fps_id]))) + out.append(self.embedder(torch.Tensor([motion_bucket_id]))) + out.append(self.embedder(torch.Tensor([augmentation]))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + + def extra_conds(self, **kwargs): + out = {} + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = comfy.conds.CONDRegular(adm) + + latent_image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if latent_image is None: + latent_image = torch.zeros_like(noise) + + if latent_image.shape[1:] != noise.shape[1:]: + latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + + if "time_conditioning" in kwargs: + out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) + + out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) + out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d65d91e7..45d603a0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 - return last_transformer_depth, context_dim, use_linear_in_transformer + 
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
+        return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
     return None
 
 def detect_unet_config(state_dict, key_prefix, dtype):
@@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
 
     context_dim = None
     use_linear_in_transformer = False
+    video_model = False
 
     current_res = 1
     count = 0
@@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
                 if context_dim is None:
                     context_dim = out[1]
                     use_linear_in_transformer = out[2]
+                    video_model = out[3]
             else:
                 transformer_depth.append(0)
@@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     unet_config["transformer_depth_middle"] = transformer_depth_middle
     unet_config['use_linear_in_transformer'] = use_linear_in_transformer
     unet_config["context_dim"] = context_dim
+
+    if video_model:
+        unet_config["extra_ff_mix_layer"] = True
+        unet_config["use_spatial_context"] = True
+        unet_config["merge_strategy"] = "learned_with_images"
+        unet_config["merge_factor"] = 0.0
+        unet_config["video_kernel_size"] = [3, 1, 1]
+        unet_config["use_temporal_resblock"] = True
+        unet_config["use_temporal_attention"] = True
+    else:
+        unet_config["use_temporal_resblock"] = False
+        unet_config["use_temporal_attention"] = False
+
     return unet_config
 
 def model_config_from_unet_config(unet_config):
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index 9e2a1c1a..fac5c995 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
-
+import math
 
 class EPS:
     def calculate_input(self, sigma, noise):
@@ -83,3 +83,47 @@ class ModelSamplingDiscrete(torch.nn.Module):
         percent = 1.0 - percent
         return self.sigma(torch.tensor(percent * 999.0)).item()
+
+class ModelSamplingContinuousEDM(torch.nn.Module):
+    def __init__(self, model_config=None):
+        super().__init__()
+        self.sigma_data = 1.0
+
+        if model_config is not None:
+            sampling_settings = model_config.sampling_settings
+        else:
+            sampling_settings = {}
+
+        sigma_min = sampling_settings.get("sigma_min", 0.002)
+        sigma_max = sampling_settings.get("sigma_max", 120.0)
+        self.set_sigma_range(sigma_min, sigma_max)
+
+    def set_sigma_range(self, sigma_min, sigma_max):
+        sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
+
+        self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        return 0.25 * sigma.log()
+
+    def sigma(self, timestep):
+        return (timestep / 0.25).exp()
+
+    def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return 999999999.9
+        if percent >= 1.0:
+            return 0.0
+        percent = 1.0 - percent
+
+        log_sigma_min = math.log(self.sigma_min)
+        return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
diff --git a/comfy/sd.py b/comfy/sd.py
index a8df3bdd..7f85540c 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -159,7 +159,15 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
 
         if config is None:
-            if "taesd_decoder.1.weight" in sd:
+            if "decoder.mid.block_1.mix_factor" in sd:
+                encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+                decoder_config = encoder_config.copy()
+                decoder_config["video_kernel_size"] = [3, 1, 1]
+                decoder_config["alpha"] = 0.0
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
+                                                            decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
+            elif "taesd_decoder.1.weight" in sd:
                 self.first_stage_model = comfy.taesd.taesd.TAESD()
             else:
                 #default SD1.x/SD2.x VAE parameters
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index fdd4ea4f..7e2ac677 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
         "model_channels": 320,
         "use_linear_in_transformer": False,
         "adm_in_channels": None,
+        "use_temporal_attention": False,
     }
 
     unet_extra_config = {
@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
         "model_channels": 320,
         "use_linear_in_transformer": True,
         "adm_in_channels": None,
+        "use_temporal_attention": False,
     }
 
     latent_format = latent_formats.SD15
@@ -88,6 +90,7 @@ class SD21UnclipL(SD20):
         "model_channels": 320,
         "use_linear_in_transformer": True,
         "adm_in_channels": 1536,
+        "use_temporal_attention": False,
     }
 
     clip_vision_prefix = "embedder.model.visual."
@@ -100,6 +103,7 @@ class SD21UnclipH(SD20):
         "model_channels": 320,
         "use_linear_in_transformer": True,
         "adm_in_channels": 2048,
+        "use_temporal_attention": False,
     }
 
     clip_vision_prefix = "embedder.model.visual."
@@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE):
         "context_dim": 1280,
         "adm_in_channels": 2560,
         "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
+        "use_temporal_attention": False,
     }
 
     latent_format = latent_formats.SDXL
@@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE):
         "use_linear_in_transformer": True,
         "transformer_depth": [0, 0, 2, 2, 10, 10],
         "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
     }
 
     latent_format = latent_formats.SDXL
@@ -203,8 +209,34 @@ class SSD1B(SDXL):
         "use_linear_in_transformer": True,
         "transformer_depth": [0, 0, 2, 2, 4, 4],
         "context_dim": 2048,
-        "adm_in_channels": 2816
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
     }
 
+class SVD_img2vid(supported_models_base.BASE):
+    unet_config = {
+        "model_channels": 320,
+        "in_channels": 8,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
+        "context_dim": 1024,
+        "adm_in_channels": 768,
+        "use_temporal_attention": True,
+        "use_temporal_resblock": True
+    }
+
+    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
+
+    latent_format = latent_formats.SD15
+
+    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.SVD_img2vid(self, device=device)
+        return out
+
+    def clip_target(self):
+        return None
 
 models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
+models += [SVD_img2vid]
diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py
index 0f4ddd9c..6991c983 100644
--- a/comfy_extras/nodes_model_advanced.py
+++ b/comfy_extras/nodes_model_advanced.py
@@ -128,6 +128,36 @@ class ModelSamplingDiscrete:
         m.add_object_patch("model_sampling", model_sampling)
         return (m, )
 
+class ModelSamplingContinuousEDM:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "sampling": (["v_prediction", "eps"],),
+                              "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
+                              }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, sampling, sigma_max, sigma_min):
+        m = model.clone()
+
+        if sampling == "eps":
+            sampling_type = comfy.model_sampling.EPS
+        elif sampling == "v_prediction":
+            sampling_type = comfy.model_sampling.V_PREDICTION
+
+        class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced()
+        model_sampling.set_sigma_range(sigma_min, sigma_max)
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
 class RescaleCFG:
     @classmethod
     def INPUT_TYPES(s):
@@ -169,5 +199,6 @@ class RescaleCFG:
 
 NODE_CLASS_MAPPINGS = {
     "ModelSamplingDiscrete": ModelSamplingDiscrete,
+    "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
     "RescaleCFG": RescaleCFG,
 }
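Note (illustrative, not part of the patch): the heart of SpatialVideoTransformer is a pair of einops rearranges that swap which axis plays the sequence role, so the spatial blocks attend over the tokens of one frame while the time_stack blocks attend over frames at one spatial position. A minimal shape walk-through; the frame count and latent size are made-up numbers.

import torch
from einops import rearrange, repeat

b, t, h, w, c = 1, 14, 16, 16, 320                   # 1 clip, 14 frames, 16x16 latent, 320 channels
x = torch.randn(b * t, h * w, c)                     # layout after proj_in: "(b t) (h w) c"

# frame-index embedding: one integer per batch element, embedded like a timestep
frame_idx = repeat(torch.arange(t), "t -> b t", b=b)
frame_idx = rearrange(frame_idx, "b t -> (b t)")     # shape (b*t,), fed to timestep_embedding

# temporal pass: the sequence axis becomes the frame axis
x_mix = rearrange(x, "(b t) s c -> (b s) t c", t=t)  # (256, 14, 320)
# ... mix_block(x_mix, context=time_context) would run here ...
x_mix = rearrange(x_mix, "(b s) t c -> (b t) s c", s=h * w, t=t)
assert x_mix.shape == x.shape                        # back to (14, 256, 320) for the AlphaBlender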
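Note (illustrative, not part of the patch): VideoResBlock runs the usual 2D ResBlock per frame, folds the frame axis into a depth dimension so the 3D convolutions in time_stack mix along time, then blends the two branches with AlphaBlender (rearrange_pattern "b t -> b 1 t 1 1"). The timestep embedding is reshaped to "b t ..." and, because the inner ResBlock uses exchange_temb_dims=True, swapped to "b c t ..." so it broadcasts over the 5D tensor. The sketch below traces only the layout; a fixed 0.5 stands in for the learned alpha and an identity stands in for the 3D conv branch.

import torch
from einops import rearrange

t, c, h, w = 14, 320, 32, 32
x = torch.randn(t, c, h, w)                                # "(b t) c h w" with b = 1, after the 2D ResBlock

x_spatial = rearrange(x, "(b t) c h w -> b c t h w", t=t)  # kept for the blend
x_temporal = x_spatial.clone()                             # would go through self.time_stack (3D convs)

alpha = 0.5                                                # stand-in for AlphaBlender's per-frame alpha
out = alpha * x_spatial + (1.0 - alpha) * x_temporal
out = rearrange(out, "b c t h w -> (b t) c h w")
assert out.shape == x.shape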
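Note (illustrative, not part of the patch): a usage sketch for the new AlphaBlender with the "learned_with_images" strategy that model detection selects for SVD checkpoints. It assumes the patched comfy package is importable; shapes follow the default "b t -> (b t) 1 1" rearrange pattern used by SpatialVideoTransformer.

import torch
from comfy.ldm.modules.diffusionmodules.util import AlphaBlender

b, t, s, c = 1, 14, 256, 320
x_spatial = torch.randn(b * t, s, c)     # output of the spatial transformer blocks
x_temporal = torch.randn(b * t, s, c)    # output of the time_stack blocks
image_only = torch.zeros(b, t)           # 0 = video frame; 1 would force the pure spatial path

blender = AlphaBlender(alpha=0.0, merge_strategy="learned_with_images")
out = blender(x_spatial=x_spatial, x_temporal=x_temporal, image_only_indicator=image_only)
# out = a * x_spatial + (1 - a) * x_temporal with a = sigmoid(mix_factor) for video frames,
# and a forced to 1 wherever image_only_indicator is set
assert out.shape == x_spatial.shape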
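Note (illustrative, not part of the patch): a standalone sketch of the math ModelSamplingContinuousEDM implements: 1000 log-spaced sigmas between sigma_min and sigma_max, the EDM conditioning timestep 0.25 * ln(sigma), and the log-space interpolation behind percent_to_sigma. The defaults 0.002 / 120.0 are the class defaults; SVD_img2vid overrides sigma_max to 700.0 through sampling_settings.

import math
import torch

sigma_min, sigma_max = 0.002, 120.0
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()

def timestep(sigma):                 # conditioning value handed to the diffusion model
    return 0.25 * sigma.log()

def percent_to_sigma(percent):       # schedulers address the range by fraction of the schedule
    percent = 1.0 - percent
    log_min = math.log(sigma_min)
    return math.exp((math.log(sigma_max) - log_min) * percent + log_min)

print(sigmas[0].item(), sigmas[-1].item())     # ~0.002 ... ~120.0
print(timestep(torch.tensor(1.0)))             # 0.0: sigma == 1 maps to conditioning value 0
print(percent_to_sigma(0.5))                   # geometric midpoint of [sigma_min, sigma_max]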
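Note (illustrative, not part of the patch): SVD_img2vid.encode_adm packs three scalars, fps minus one, motion_bucket_id and augmentation_level, through the 256-dimensional Timestep embedder and concatenates them, which is where adm_in_channels = 768 in the SVD_img2vid config comes from. The helper name below is made up; it only mirrors the shapes using the existing timestep_embedding utility.

import torch
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding

def encode_adm_sketch(fps=6, motion_bucket_id=127, augmentation_level=0.0):
    parts = [timestep_embedding(torch.tensor([float(v)]), 256)
             for v in (fps - 1, motion_bucket_id, augmentation_level)]
    return torch.flatten(torch.cat(parts)).unsqueeze(dim=0)   # (1, 3 * 256) = (1, 768)

print(encode_adm_sketch().shape)    # torch.Size([1, 768])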
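Note (illustrative, not part of the patch): partialclass in temporal_ae.py pre-binds constructor arguments so VideoDecoder can hand video-specific parameters to the conv/attn/resnet classes the base Decoder instantiates. A toy demonstration of the helper (Toy is a made-up class):

import functools

def partialclass(cls, *args, **kwargs):
    # same helper as in temporal_ae.py: subclass with some __init__ args pre-bound
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
    return NewCls

class Toy:
    def __init__(self, in_channels, video_kernel_size=3):
        self.in_channels = in_channels
        self.video_kernel_size = video_kernel_size

ToyVideo = partialclass(Toy, video_kernel_size=[3, 1, 1])
print(ToyVideo(8).video_kernel_size)    # [3, 1, 1]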