Repository (mirror): https://github.com/comfyanonymous/ComfyUI.git

commit 2b13939044
parent 95d796fc85

    Remove some useless code.
First file in the diff (the ControlNet module; the file path is not shown in this view):

@@ -13,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
 )
 
 from ..ldm.modules.attention import SpatialTransformer
-from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
 from ..ldm.util import exists
 
 
@@ -57,6 +57,7 @@ class ControlNet(nn.Module):
         transformer_depth_middle=None,
     ):
         super().__init__()
+        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
 
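The new assertion is what makes the later hunks safe: once use_spatial_transformer is guaranteed to be True at construction time, every `AttentionBlock(...) if not use_spatial_transformer else SpatialTransformer(...)` expression below always takes its else branch, so that branch can simply be inlined. A toy sketch of the pattern (hypothetical names, not code from the repository):

    # Toy illustration of the dead-branch elimination performed in the hunks
    # below: after the assert, the conditional expression can only ever
    # evaluate to its else branch.
    def pick_block_before(use_spatial_transformer: bool) -> str:
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        return "AttentionBlock" if not use_spatial_transformer else "SpatialTransformer"

    def pick_block_after(use_spatial_transformer: bool) -> str:
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        return "SpatialTransformer"

    assert pick_block_before(True) == pick_block_after(True) == "SpatialTransformer"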
@@ -200,13 +201,7 @@ class ControlNet(nn.Module):
 
                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                     layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                        SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint
@@ -259,13 +254,7 @@ class ControlNet(nn.Module):
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+            SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint
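With both ControlNet hunks applied, the attention in the input blocks and the middle block is always a SpatialTransformer (self-attention plus cross-attention against the conditioning). A hedged usage sketch of such a block, assuming the signature used in the calls above (positional in_channels, n_heads, d_head) and an import path inferred from the relative imports in this diff; the concrete channel and head counts are made up for illustration:

    import torch
    from comfy.ldm.modules.attention import SpatialTransformer  # path assumed

    # 320 channels split into 8 heads of 40 dims; cross-attention against a
    # 768-dim conditioning sequence (e.g. CLIP text embeddings).
    block = SpatialTransformer(
        320, 8, 40, depth=1, context_dim=768,
        disable_self_attn=False, use_linear=True, use_checkpoint=False,
    )
    h = torch.randn(1, 320, 32, 32)   # UNet feature map [B, C, H, W]
    ctx = torch.randn(1, 77, 768)     # conditioning tokens [B, T, context_dim]
    out = block(h, context=ctx)       # same [B, C, H, W] shape as the input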
Second file in the diff (the diffusion UNet / openaimodel module; the file path is not shown in this view):

@@ -19,45 +19,6 @@ from ..attention import SpatialTransformer
 from comfy.ldm.util import exists
 
 
-# dummy replace
-def convert_module_to_f16(x):
-    pass
-
-def convert_module_to_f32(x):
-    pass
-
-
-## go
-class AttentionPool2d(nn.Module):
-    """
-    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
-    """
-
-    def __init__(
-        self,
-        spacial_dim: int,
-        embed_dim: int,
-        num_heads_channels: int,
-        output_dim: int = None,
-    ):
-        super().__init__()
-        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
-        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
-        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
-        self.num_heads = embed_dim // num_heads_channels
-        self.attention = QKVAttention(self.num_heads)
-
-    def forward(self, x):
-        b, c, *_spatial = x.shape
-        x = x.reshape(b, c, -1)  # NC(HW)
-        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
-        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
-        x = self.qkv_proj(x)
-        x = self.attention(x)
-        x = self.c_proj(x)
-        return x[:, :, 0]
-
-
 class TimestepBlock(nn.Module):
     """
     Any module where forward() takes timestep embeddings as a second argument.
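The deleted AttentionPool2d (CLIP-style attention pooling) was not referenced anywhere in this UNet, which is what made it removable. For context, the idea it implements is to flatten the spatial grid, prepend the mean as an extra token, run attention, and keep only that first token's output. A minimal standalone re-sketch of that idea using stock PyTorch (my own sketch, not the deleted implementation; it omits the learned positional embedding):

    import torch
    import torch.nn as nn

    class MeanTokenAttentionPool(nn.Module):
        """Pool a [B, C, H, W] feature map into a single [B, C] vector."""

        def __init__(self, embed_dim: int, num_heads: int):
            super().__init__()
            self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, c, h, w = x.shape
            tokens = x.flatten(2).transpose(1, 2)                      # [B, HW, C]
            tokens = torch.cat([tokens.mean(1, keepdim=True), tokens], dim=1)
            out, _ = self.attn(tokens, tokens, tokens)                 # [B, HW+1, C]
            return out[:, 0]                                           # mean-token output

    pooled = MeanTokenAttentionPool(64, 8)(torch.randn(2, 64, 16, 16))
    assert pooled.shape == (2, 64)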
@@ -138,19 +99,6 @@ class Upsample(nn.Module):
             x = self.conv(x)
         return x
 
-class TransposedUpsample(nn.Module):
-    'Learned 2x upsampling without padding'
-    def __init__(self, channels, out_channels=None, ks=5):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-
-        self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
-
-    def forward(self,x):
-        return self.up(x)
-
-
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.
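The deleted TransposedUpsample was likewise unused: the surviving Upsample in this file already handles 2x upsampling via interpolation followed by an optional convolution. A quick standalone shape check of the two approaches (note that a stride-2, kernel-5 transposed convolution without padding does not land exactly on 2x):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(1, 8, 16, 16)
    up_tconv = nn.ConvTranspose2d(8, 8, kernel_size=5, stride=2)(x)   # learned upsample
    up_interp = F.interpolate(x, scale_factor=2, mode="nearest")      # interpolation, as in Upsample
    print(up_tconv.shape, up_interp.shape)  # [1, 8, 35, 35] vs [1, 8, 32, 32]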
@@ -296,142 +244,6 @@ class ResBlock(TimestepBlock):
         h = self.out_layers(h)
         return self.skip_connection(x) + h
 
-
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        use_new_attention_order=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        if use_new_attention_order:
-            # split qkv before split heads
-            self.attention = QKVAttention(self.num_heads)
-        else:
-            # split heads before split qkv
-            self.attention = QKVAttentionLegacy(self.num_heads)
-
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x):
-        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
-        #return pt_checkpoint(self._forward, x)  # pytorch
-
-    def _forward(self, x):
-        b, c, *spatial = x.shape
-        x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
-
-
-def count_flops_attn(model, _x, y):
-    """
-    A counter for the `thop` package to count the operations in an
-    attention operation.
-    Meant to be used like:
-        macs, params = thop.profile(
-            model,
-            inputs=(inputs, timestamps),
-            custom_ops={QKVAttention: QKVAttention.count_flops},
-        )
-    """
-    b, c, *spatial = y[0].shape
-    num_spatial = int(np.prod(spatial))
-    # We perform two matmuls with the same number of ops.
-    # The first computes the weight matrix, the second computes
-    # the combination of the value vectors.
-    matmul_ops = 2 * b * (num_spatial ** 2) * c
-    model.total_ops += th.DoubleTensor([matmul_ops])
-
-
-class QKVAttentionLegacy(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-    @staticmethod
-    def count_flops(model, _x, y):
-        return count_flops_attn(model, _x, y)
-
-
 class Timestep(nn.Module):
     def __init__(self, dim):
         super().__init__()
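One detail worth noting from the deleted QKVAttention code is its comment "More stable with f16 than dividing afterwards": it scales q and k each by ch**-0.25 before the matmul rather than dividing the attention logits by sqrt(ch) afterwards. The two are mathematically identical, but pre-scaling keeps the intermediate products smaller, which matters in float16. A small standalone check of the equivalence:

    import torch as th

    b, ch, t = 2, 64, 16
    q = th.randn(b, ch, t)
    k = th.randn(b, ch, t)

    scale = 1 / (ch ** 0.25)                        # == 1 / sqrt(sqrt(ch))
    pre = th.einsum("bct,bcs->bts", q * scale, k * scale)
    post = th.einsum("bct,bcs->bts", q, k) / (ch ** 0.5)

    assert th.allclose(pre, post, atol=1e-5)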
@@ -507,6 +319,7 @@ class UNetModel(nn.Module):
         device=None,
     ):
         super().__init__()
+        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
 
@@ -631,14 +444,7 @@ class UNetModel(nn.Module):
                     disabled_sa = False
 
                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                    layers.append(SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
@@ -693,13 +499,7 @@ class UNetModel(nn.Module):
                 dtype=self.dtype,
                 device=device,
             ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+            SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
@@ -751,13 +551,7 @@ class UNetModel(nn.Module):
 
                 if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                     layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
+                        SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
@@ -797,22 +591,6 @@ class UNetModel(nn.Module):
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
         )
 
-    def convert_to_fp16(self):
-        """
-        Convert the torso of the model to float16.
-        """
-        self.input_blocks.apply(convert_module_to_f16)
-        self.middle_block.apply(convert_module_to_f16)
-        self.output_blocks.apply(convert_module_to_f16)
-
-    def convert_to_fp32(self):
-        """
-        Convert the torso of the model to float32.
-        """
-        self.input_blocks.apply(convert_module_to_f32)
-        self.middle_block.apply(convert_module_to_f32)
-        self.output_blocks.apply(convert_module_to_f32)
-
     def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
         """
         Apply the model to an input batch.
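The deleted convert_to_fp16/convert_to_fp32 methods only called convert_module_to_f16/convert_module_to_f32, which the first hunk of this file shows had already been reduced to no-op stubs ("# dummy replace" ... pass), so the whole post-hoc conversion path did nothing; dtype is instead plumbed in at construction time (see dtype=self.dtype in the hunks above). For context, a hedged sketch of what such conversion helpers typically do in upstream guided-diffusion-style code (my own re-sketch, not this repository's implementation):

    import torch.nn as nn

    def convert_module_to_f16(m: nn.Module) -> None:
        # Cast only the heavy compute layers; normalization usually stays fp32.
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            m.half()

    def convert_module_to_f32(m: nn.Module) -> None:
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            m.float()

    # Hypothetical usage, mirroring the deleted methods:
    #     unet.input_blocks.apply(convert_module_to_f16)
    #     unet.middle_block.apply(convert_module_to_f16)
    #     unet.output_blocks.apply(convert_module_to_f16)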