mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Refactor comfy.ops
comfy.ops -> comfy.ops.disable_weight_init This should make it more clear what they actually do. Some unused code has also been removed.
This commit is contained in:
parent
b0aab1e4ea
commit
77755ab8db
@ -53,7 +53,7 @@ class ControlNet(nn.Module):
|
|||||||
transformer_depth_middle=None,
|
transformer_depth_middle=None,
|
||||||
transformer_depth_output=None,
|
transformer_depth_output=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops,
|
operations=comfy.ops.disable_weight_init,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -38,7 +38,7 @@ class ClipVisionModel():
|
|||||||
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
|
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
|
||||||
self.dtype = torch.float16
|
self.dtype = torch.float16
|
||||||
|
|
||||||
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops)
|
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init)
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
|
@ -208,9 +208,9 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
|
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias)
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module):
|
class Conv2d(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -247,24 +247,9 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
def conv_nd(self, dims, *args, **kwargs):
|
|
||||||
if dims == 2:
|
|
||||||
return self.Conv2d(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
|
||||||
|
|
||||||
class Conv3d(comfy.ops.Conv3d):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class GroupNorm(comfy.ops.GroupNorm):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class LayerNorm(comfy.ops.LayerNorm):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
@ -278,7 +263,9 @@ class ControlLora(ControlNet):
|
|||||||
controlnet_config = model.model_config.unet_config.copy()
|
controlnet_config = model.model_config.unet_config.copy()
|
||||||
controlnet_config.pop("out_channels")
|
controlnet_config.pop("out_channels")
|
||||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||||
controlnet_config["operations"] = ControlLoraOps()
|
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
|
||||||
|
pass
|
||||||
|
controlnet_config["operations"] = control_lora_ops
|
||||||
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||||
dtype = model.get_dtype()
|
dtype = model.get_dtype()
|
||||||
self.control_model.to(dtype)
|
self.control_model.to(dtype)
|
||||||
|
@ -19,6 +19,7 @@ if model_management.xformers_enabled():
|
|||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
# CrossAttn precision handling
|
# CrossAttn precision handling
|
||||||
if args.dont_upcast_attention:
|
if args.dont_upcast_attention:
|
||||||
@ -55,7 +56,7 @@ def init_(tensor):
|
|||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
|
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
|
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ class GEGLU(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
@ -356,7 +357,7 @@ def optimized_attention_for_device(device, mask=False):
|
|||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = dim_head * heads
|
inner_dim = dim_head * heads
|
||||||
context_dim = default(context_dim, query_dim)
|
context_dim = default(context_dim, query_dim)
|
||||||
@ -389,7 +390,7 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
||||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
|
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.ff_in = ff_in or inner_dim is not None
|
self.ff_in = ff_in or inner_dim is not None
|
||||||
@ -558,7 +559,7 @@ class SpatialTransformer(nn.Module):
|
|||||||
def __init__(self, in_channels, n_heads, d_head,
|
def __init__(self, in_channels, n_heads, d_head,
|
||||||
depth=1, dropout=0., context_dim=None,
|
depth=1, dropout=0., context_dim=None,
|
||||||
disable_self_attn=False, use_linear=False,
|
disable_self_attn=False, use_linear=False,
|
||||||
use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
|
use_checkpoint=True, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if exists(context_dim) and not isinstance(context_dim, list):
|
if exists(context_dim) and not isinstance(context_dim, list):
|
||||||
context_dim = [context_dim] * depth
|
context_dim = [context_dim] * depth
|
||||||
@ -632,7 +633,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
|||||||
disable_self_attn=False,
|
disable_self_attn=False,
|
||||||
disable_temporal_crossattention=False,
|
disable_temporal_crossattention=False,
|
||||||
max_time_embed_period: int = 10000,
|
max_time_embed_period: int = 10000,
|
||||||
dtype=None, device=None, operations=comfy.ops
|
dtype=None, device=None, operations=ops
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -8,6 +8,7 @@ from typing import Optional, Any
|
|||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
if model_management.xformers_enabled_vae():
|
if model_management.xformers_enabled_vae():
|
||||||
import xformers
|
import xformers
|
||||||
@ -48,7 +49,7 @@ class Upsample(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
self.conv = comfy.ops.Conv2d(in_channels,
|
self.conv = ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -78,7 +79,7 @@ class Downsample(nn.Module):
|
|||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
self.conv = comfy.ops.Conv2d(in_channels,
|
self.conv = ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
@ -105,30 +106,30 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
self.swish = torch.nn.SiLU(inplace=True)
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
self.norm1 = Normalize(in_channels)
|
self.norm1 = Normalize(in_channels)
|
||||||
self.conv1 = comfy.ops.Conv2d(in_channels,
|
self.conv1 = ops.Conv2d(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
if temb_channels > 0:
|
if temb_channels > 0:
|
||||||
self.temb_proj = comfy.ops.Linear(temb_channels,
|
self.temb_proj = ops.Linear(temb_channels,
|
||||||
out_channels)
|
out_channels)
|
||||||
self.norm2 = Normalize(out_channels)
|
self.norm2 = Normalize(out_channels)
|
||||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||||
self.conv2 = comfy.ops.Conv2d(out_channels,
|
self.conv2 = ops.Conv2d(out_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
self.conv_shortcut = comfy.ops.Conv2d(in_channels,
|
self.conv_shortcut = ops.Conv2d(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = comfy.ops.Conv2d(in_channels,
|
self.nin_shortcut = ops.Conv2d(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -245,22 +246,22 @@ class AttnBlock(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
self.norm = Normalize(in_channels)
|
||||||
self.q = comfy.ops.Conv2d(in_channels,
|
self.q = ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.k = comfy.ops.Conv2d(in_channels,
|
self.k = ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.v = comfy.ops.Conv2d(in_channels,
|
self.v = ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
self.proj_out = comfy.ops.Conv2d(in_channels,
|
self.proj_out = ops.Conv2d(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -312,14 +313,14 @@ class Model(nn.Module):
|
|||||||
# timestep embedding
|
# timestep embedding
|
||||||
self.temb = nn.Module()
|
self.temb = nn.Module()
|
||||||
self.temb.dense = nn.ModuleList([
|
self.temb.dense = nn.ModuleList([
|
||||||
comfy.ops.Linear(self.ch,
|
ops.Linear(self.ch,
|
||||||
self.temb_ch),
|
self.temb_ch),
|
||||||
comfy.ops.Linear(self.temb_ch,
|
ops.Linear(self.temb_ch,
|
||||||
self.temb_ch),
|
self.temb_ch),
|
||||||
])
|
])
|
||||||
|
|
||||||
# downsampling
|
# downsampling
|
||||||
self.conv_in = comfy.ops.Conv2d(in_channels,
|
self.conv_in = ops.Conv2d(in_channels,
|
||||||
self.ch,
|
self.ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -388,7 +389,7 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
self.norm_out = Normalize(block_in)
|
self.norm_out = Normalize(block_in)
|
||||||
self.conv_out = comfy.ops.Conv2d(block_in,
|
self.conv_out = ops.Conv2d(block_in,
|
||||||
out_ch,
|
out_ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -461,7 +462,7 @@ class Encoder(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
# downsampling
|
# downsampling
|
||||||
self.conv_in = comfy.ops.Conv2d(in_channels,
|
self.conv_in = ops.Conv2d(in_channels,
|
||||||
self.ch,
|
self.ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -506,7 +507,7 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
self.norm_out = Normalize(block_in)
|
self.norm_out = Normalize(block_in)
|
||||||
self.conv_out = comfy.ops.Conv2d(block_in,
|
self.conv_out = ops.Conv2d(block_in,
|
||||||
2*z_channels if double_z else z_channels,
|
2*z_channels if double_z else z_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
@ -541,7 +542,7 @@ class Decoder(nn.Module):
|
|||||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||||
conv_out_op=comfy.ops.Conv2d,
|
conv_out_op=ops.Conv2d,
|
||||||
resnet_op=ResnetBlock,
|
resnet_op=ResnetBlock,
|
||||||
attn_op=AttnBlock,
|
attn_op=AttnBlock,
|
||||||
**ignorekwargs):
|
**ignorekwargs):
|
||||||
@ -565,7 +566,7 @@ class Decoder(nn.Module):
|
|||||||
self.z_shape, np.prod(self.z_shape)))
|
self.z_shape, np.prod(self.z_shape)))
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
self.conv_in = comfy.ops.Conv2d(z_channels,
|
self.conv_in = ops.Conv2d(z_channels,
|
||||||
block_in,
|
block_in,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
|
@ -12,13 +12,13 @@ from .util import (
|
|||||||
checkpoint,
|
checkpoint,
|
||||||
avg_pool_nd,
|
avg_pool_nd,
|
||||||
zero_module,
|
zero_module,
|
||||||
normalization,
|
|
||||||
timestep_embedding,
|
timestep_embedding,
|
||||||
AlphaBlender,
|
AlphaBlender,
|
||||||
)
|
)
|
||||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||||
from comfy.ldm.util import exists
|
from comfy.ldm.util import exists
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
class TimestepBlock(nn.Module):
|
class TimestepBlock(nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -70,7 +70,7 @@ class Upsample(nn.Module):
|
|||||||
upsampling occurs in the inner-two dimensions.
|
upsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
@ -106,7 +106,7 @@ class Downsample(nn.Module):
|
|||||||
downsampling occurs in the inner-two dimensions.
|
downsampling occurs in the inner-two dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.out_channels = out_channels or channels
|
self.out_channels = out_channels or channels
|
||||||
@ -159,7 +159,7 @@ class ResBlock(TimestepBlock):
|
|||||||
skip_t_emb=False,
|
skip_t_emb=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops
|
operations=ops
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
@ -284,7 +284,7 @@ class VideoResBlock(ResBlock):
|
|||||||
down: bool = False,
|
down: bool = False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops
|
operations=ops
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
channels,
|
channels,
|
||||||
@ -434,7 +434,7 @@ class UNetModel(nn.Module):
|
|||||||
disable_temporal_crossattention=False,
|
disable_temporal_crossattention=False,
|
||||||
max_ddpm_temb_period=10000,
|
max_ddpm_temb_period=10000,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops,
|
operations=ops,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||||
@ -581,7 +581,7 @@ class UNetModel(nn.Module):
|
|||||||
up=False,
|
up=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=comfy.ops
|
operations=ops
|
||||||
):
|
):
|
||||||
if self.use_temporal_resblocks:
|
if self.use_temporal_resblocks:
|
||||||
return VideoResBlock(
|
return VideoResBlock(
|
||||||
|
@ -16,7 +16,6 @@ import numpy as np
|
|||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
|
|
||||||
from comfy.ldm.util import instantiate_from_config
|
from comfy.ldm.util import instantiate_from_config
|
||||||
import comfy.ops
|
|
||||||
|
|
||||||
class AlphaBlender(nn.Module):
|
class AlphaBlender(nn.Module):
|
||||||
strategies = ["learned", "fixed", "learned_with_images"]
|
strategies = ["learned", "fixed", "learned_with_images"]
|
||||||
@ -273,46 +272,6 @@ def mean_flat(tensor):
|
|||||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||||
|
|
||||||
|
|
||||||
def normalization(channels, dtype=None):
|
|
||||||
"""
|
|
||||||
Make a standard normalization layer.
|
|
||||||
:param channels: number of input channels.
|
|
||||||
:return: an nn.Module for normalization.
|
|
||||||
"""
|
|
||||||
return GroupNorm32(32, channels, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
|
||||||
class SiLU(nn.Module):
|
|
||||||
def forward(self, x):
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
class GroupNorm32(nn.GroupNorm):
|
|
||||||
def forward(self, x):
|
|
||||||
return super().forward(x.float()).type(x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def conv_nd(dims, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Create a 1D, 2D, or 3D convolution module.
|
|
||||||
"""
|
|
||||||
if dims == 1:
|
|
||||||
return nn.Conv1d(*args, **kwargs)
|
|
||||||
elif dims == 2:
|
|
||||||
return comfy.ops.Conv2d(*args, **kwargs)
|
|
||||||
elif dims == 3:
|
|
||||||
return nn.Conv3d(*args, **kwargs)
|
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
|
||||||
|
|
||||||
|
|
||||||
def linear(*args, **kwargs):
|
|
||||||
"""
|
|
||||||
Create a linear module.
|
|
||||||
"""
|
|
||||||
return comfy.ops.Linear(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def avg_pool_nd(dims, *args, **kwargs):
|
def avg_pool_nd(dims, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a 1D, 2D, or 3D average pooling module.
|
Create a 1D, 2D, or 3D average pooling module.
|
||||||
|
@ -5,6 +5,7 @@ import torch
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
from .diffusionmodules.model import (
|
from .diffusionmodules.model import (
|
||||||
AttnBlock,
|
AttnBlock,
|
||||||
@ -130,9 +131,9 @@ class AttnVideoBlock(AttnBlock):
|
|||||||
|
|
||||||
time_embed_dim = self.in_channels * 4
|
time_embed_dim = self.in_channels * 4
|
||||||
self.video_time_embed = torch.nn.Sequential(
|
self.video_time_embed = torch.nn.Sequential(
|
||||||
comfy.ops.Linear(self.in_channels, time_embed_dim),
|
ops.Linear(self.in_channels, time_embed_dim),
|
||||||
torch.nn.SiLU(),
|
torch.nn.SiLU(),
|
||||||
comfy.ops.Linear(time_embed_dim, self.in_channels),
|
ops.Linear(time_embed_dim, self.in_channels),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.merge_strategy = merge_strategy
|
self.merge_strategy = merge_strategy
|
||||||
|
@ -48,7 +48,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
if self.manual_cast_dtype is not None:
|
if self.manual_cast_dtype is not None:
|
||||||
operations = comfy.ops.manual_cast
|
operations = comfy.ops.manual_cast
|
||||||
else:
|
else:
|
||||||
operations = comfy.ops
|
operations = comfy.ops.disable_weight_init
|
||||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model_sampling = model_sampling(model_config, model_type)
|
self.model_sampling = model_sampling(model_config, model_type)
|
||||||
|
99
comfy/ops.py
99
comfy/ops.py
@ -1,67 +1,27 @@
|
|||||||
import torch
|
import torch
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
class Linear(torch.nn.Linear):
|
class disable_weight_init:
|
||||||
|
class Linear(torch.nn.Linear):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class Conv2d(torch.nn.Conv2d):
|
class Conv2d(torch.nn.Conv2d):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class Conv3d(torch.nn.Conv3d):
|
class Conv3d(torch.nn.Conv3d):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class GroupNorm(torch.nn.GroupNorm):
|
class GroupNorm(torch.nn.GroupNorm):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm):
|
class LayerNorm(torch.nn.LayerNorm):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def conv_nd(dims, *args, **kwargs):
|
|
||||||
if dims == 2:
|
|
||||||
return Conv2d(*args, **kwargs)
|
|
||||||
elif dims == 3:
|
|
||||||
return Conv3d(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input):
|
|
||||||
bias = None
|
|
||||||
if s.bias is not None:
|
|
||||||
bias = s.bias.to(device=input.device, dtype=input.dtype)
|
|
||||||
weight = s.weight.to(device=input.device, dtype=input.dtype)
|
|
||||||
return weight, bias
|
|
||||||
|
|
||||||
class manual_cast:
|
|
||||||
class Linear(Linear):
|
|
||||||
def forward(self, input):
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
|
||||||
|
|
||||||
class Conv2d(Conv2d):
|
|
||||||
def forward(self, input):
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
|
||||||
return self._conv_forward(input, weight, bias)
|
|
||||||
|
|
||||||
class Conv3d(Conv3d):
|
|
||||||
def forward(self, input):
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
|
||||||
return self._conv_forward(input, weight, bias)
|
|
||||||
|
|
||||||
class GroupNorm(GroupNorm):
|
|
||||||
def forward(self, input):
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
|
||||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
|
||||||
|
|
||||||
class LayerNorm(LayerNorm):
|
|
||||||
def forward(self, input):
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
|
||||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def conv_nd(s, dims, *args, **kwargs):
|
def conv_nd(s, dims, *args, **kwargs):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
@ -71,20 +31,35 @@ class manual_cast:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
|
|
||||||
@contextmanager
|
def cast_bias_weight(s, input):
|
||||||
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
|
bias = None
|
||||||
old_torch_nn_linear = torch.nn.Linear
|
if s.bias is not None:
|
||||||
force_device = device
|
bias = s.bias.to(device=input.device, dtype=input.dtype)
|
||||||
force_dtype = dtype
|
weight = s.weight.to(device=input.device, dtype=input.dtype)
|
||||||
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
return weight, bias
|
||||||
if force_device is not None:
|
|
||||||
device = force_device
|
|
||||||
if force_dtype is not None:
|
|
||||||
dtype = force_dtype
|
|
||||||
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
torch.nn.Linear = linear_with_dtype
|
class manual_cast(disable_weight_init):
|
||||||
try:
|
class Linear(disable_weight_init.Linear):
|
||||||
yield
|
def forward(self, input):
|
||||||
finally:
|
weight, bias = cast_bias_weight(self, input)
|
||||||
torch.nn.Linear = old_torch_nn_linear
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
class Conv2d(disable_weight_init.Conv2d):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
|
class Conv3d(disable_weight_init.Conv3d):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
|
class GroupNorm(disable_weight_init.GroupNorm):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||||
|
|
||||||
|
class LayerNorm(disable_weight_init.LayerNorm):
|
||||||
|
def forward(self, input):
|
||||||
|
weight, bias = cast_bias_weight(self, input)
|
||||||
|
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||||
|
Loading…
Reference in New Issue
Block a user