All the unet weights should now be initialized with the right dtype.

comfyanonymous 2023-06-15 18:42:30 -04:00
parent cf3974c829
commit ae43f09ef7
3 changed files with 29 additions and 23 deletions
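
Background for the pattern used throughout this commit: PyTorch layers accept a factory `dtype` argument, so passing it at construction time creates the parameters directly in the target dtype instead of allocating fp32 weights and casting them afterwards. A minimal sketch with plain torch.nn.Linear standing in for comfy.ops.Linear (whose wrapper is not shown in this diff):

    import torch
    import torch.nn as nn

    # With the factory kwarg the parameters are allocated in fp16 directly.
    fp16_linear = nn.Linear(320, 640, dtype=torch.float16)
    print(fp16_linear.weight.dtype)  # torch.float16

    # Without it the weights start out in the default dtype (normally fp32)
    # and would need an extra cast to reach the same state.
    fp32_linear = nn.Linear(320, 640)
    print(fp32_linear.weight.dtype)  # torch.float32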

View File

@@ -51,9 +51,9 @@ def init_(tensor):
# feedforward
class GEGLU(nn.Module):
-def __init__(self, dim_in, dim_out):
+def __init__(self, dim_in, dim_out, dtype=None):
super().__init__()
-self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
+self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
@@ -68,7 +68,7 @@ class FeedForward(nn.Module):
project_in = nn.Sequential(
comfy.ops.Linear(dim, inner_dim, dtype=dtype),
nn.GELU()
-) if not glu else GEGLU(dim, inner_dim)
+) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
self.net = nn.Sequential(
project_in,
@@ -89,8 +89,8 @@ def zero_module(module):
return module
-def Normalize(in_channels):
-return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, dtype=None):
+return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
class SpatialSelfAttention(nn.Module):
@@ -594,7 +594,7 @@ class SpatialTransformer(nn.Module):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
-self.norm = Normalize(in_channels)
+self.norm = Normalize(in_channels, dtype=dtype)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
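
Applying the hunks above, the patched GEGLU ends up looking roughly like the sketch below. This is an approximation for illustration only: plain nn.Linear stands in for comfy.ops.Linear, and the return line of forward() is not part of the displayed hunk, so it is reconstructed from the usual GEGLU definition.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GEGLU(nn.Module):
        # Gated-GELU projection; forwarding `dtype` means the projection
        # weights are created in the UNet's compute dtype from the start.
        def __init__(self, dim_in, dim_out, dtype=None):
            super().__init__()
            self.proj = nn.Linear(dim_in, dim_out * 2, dtype=dtype)

        def forward(self, x):
            x, gate = self.proj(x).chunk(2, dim=-1)
            return x * F.gelu(gate)

    block = GEGLU(320, 1280, dtype=torch.float16)
    print(block.proj.weight.dtype)  # torch.float16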

View File

@@ -111,14 +111,14 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
-def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
-self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
-def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
-dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
)
else:
assert self.channels == self.out_channels
@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
-normalization(channels),
+normalization(channels, dtype=dtype),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
)
@@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
self.updown = up or down
if up:
-self.h_upd = Upsample(channels, False, dims)
-self.x_upd = Upsample(channels, False, dims)
+self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+self.x_upd = Upsample(channels, False, dims, dtype=dtype)
elif down:
-self.h_upd = Downsample(channels, False, dims)
-self.x_upd = Downsample(channels, False, dims)
+self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+self.x_upd = Downsample(channels, False, dims, dtype=dtype)
else:
self.h_upd = self.x_upd = nn.Identity()
@@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
linear(
emb_channels,
-2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
),
)
self.out_layers = nn.Sequential(
-normalization(self.out_channels),
+normalization(self.out_channels, dtype=dtype),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
@@ -604,6 +604,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+dtype=self.dtype
)
]
ch = mult * model_channels
@@ -651,10 +652,11 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
+dtype=self.dtype
)
if resblock_updown
else Downsample(
-ch, conv_resample, dims=dims, out_channels=out_ch
+ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
)
)
)
@@ -679,6 +681,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+dtype=self.dtype
),
AttentionBlock(
ch,
@@ -698,6 +701,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+dtype=self.dtype
),
)
self._feature_size += ch
@@ -715,6 +719,7 @@ class UNetModel(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
+dtype=self.dtype
)
]
ch = model_channels * mult
@@ -758,18 +763,19 @@ class UNetModel(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
+dtype=self.dtype
)
if resblock_updown
-else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
-normalization(ch),
+normalization(ch, dtype=self.dtype),
nn.SiLU(),
-zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
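
The hunks above thread `self.dtype` into every ResBlock, Upsample, Downsample, normalization and output conv that UNetModel builds. A quick way to confirm the result is to histogram the parameter dtypes after construction; the helper below is illustrative and not part of the commit, shown on a toy stand-in module:

    import collections
    import torch

    def param_dtype_histogram(model: torch.nn.Module):
        # Count parameter tensors per dtype; after this commit the real UNet
        # should produce a single entry, e.g. {torch.float16: N}.
        return dict(collections.Counter(p.dtype for p in model.parameters()))

    toy = torch.nn.Sequential(
        torch.nn.Conv2d(4, 320, 3, padding=1, dtype=torch.float16),
        torch.nn.GroupNorm(32, 320, eps=1e-6, affine=True, dtype=torch.float16),
    )
    print(param_dtype_histogram(toy))  # {torch.float16: 4}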

View File

@@ -206,13 +206,13 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape))))
-def normalization(channels):
+def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
-return GroupNorm32(32, channels)
+return GroupNorm32(32, channels, dtype=dtype)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
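
One detail worth noting about the new signatures: they all default to `dtype=None`, and for PyTorch's factory kwargs passing None is the same as omitting the argument. Existing callers that never specify a dtype therefore keep getting parameters in the process default dtype (normally fp32), so the change is backward compatible. For example:

    import torch
    import torch.nn as nn

    # dtype=None falls back to torch.get_default_dtype(), usually fp32.
    print(nn.GroupNorm(32, 320, dtype=None).weight.dtype)           # torch.float32
    print(nn.GroupNorm(32, 320, dtype=torch.float16).weight.dtype)  # torch.float16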