mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
All the unet weights should now be initialized with the right dtype.
This commit is contained in:
parent
cf3974c829
commit
ae43f09ef7
@ -51,9 +51,9 @@ def init_(tensor):
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
def __init__(self, dim_in, dim_out, dtype=None):
|
||||
super().__init__()
|
||||
self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
|
||||
self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
@ -68,7 +68,7 @@ class FeedForward(nn.Module):
|
||||
project_in = nn.Sequential(
|
||||
comfy.ops.Linear(dim, inner_dim, dtype=dtype),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
@ -89,8 +89,8 @@ def zero_module(module):
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
def Normalize(in_channels, dtype=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
@ -594,7 +594,7 @@ class SpatialTransformer(nn.Module):
|
||||
context_dim = [context_dim]
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
self.norm = Normalize(in_channels, dtype=dtype)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
|
@ -111,14 +111,14 @@ class Upsample(nn.Module):
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
|
||||
|
||||
def forward(self, x, output_shape=None):
|
||||
assert x.shape[1] == self.channels
|
||||
@ -160,7 +160,7 @@ class Downsample(nn.Module):
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@ -169,7 +169,7 @@ class Downsample(nn.Module):
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
normalization(channels, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||
)
|
||||
@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
self.h_upd = Upsample(channels, False, dims, dtype=dtype)
|
||||
self.x_upd = Upsample(channels, False, dims, dtype=dtype)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
self.h_upd = Downsample(channels, False, dims, dtype=dtype)
|
||||
self.x_upd = Downsample(channels, False, dims, dtype=dtype)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels),
|
||||
normalization(self.out_channels, dtype=dtype),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
@ -604,6 +604,7 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
@ -651,10 +652,11 @@ class UNetModel(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
dtype=self.dtype
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -679,6 +681,7 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
@ -698,6 +701,7 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
@ -715,6 +719,7 @@ class UNetModel(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
@ -758,18 +763,19 @@ class UNetModel(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
dtype=self.dtype
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
normalization(ch, dtype=self.dtype),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
|
@ -206,13 +206,13 @@ def mean_flat(tensor):
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
def normalization(channels, dtype=None):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
return GroupNorm32(32, channels, dtype=dtype)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
|
Loading…
Reference in New Issue
Block a user