Initialize more unet weights as the right dtype.

This commit is contained in:
comfyanonymous 2023-06-15 15:00:10 -04:00
parent e21d9ad445
commit 7bf89ba923

View File

@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
use_checkpoint=False, use_checkpoint=False,
up=False, up=False,
down=False, down=False,
dtype=None
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -221,7 +222,7 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
normalization(channels), normalization(channels),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1), conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
) )
self.updown = up or down self.updown = up or down
@ -247,7 +248,7 @@ class ResBlock(TimestepBlock):
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
), ),
) )
@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = conv_nd( self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1 dims, channels, self.out_channels, 3, padding=1, dtype=dtype
) )
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
def forward(self, x, emb): def forward(self, x, emb):
""" """
@ -558,9 +559,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4 time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim), linear(model_channels, time_embed_dim, dtype=self.dtype),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim), linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
) )
if self.num_classes is not None: if self.num_classes is not None:
@ -573,9 +574,9 @@ class UNetModel(nn.Module):
assert adm_in_channels is not None assert adm_in_channels is not None
self.label_emb = nn.Sequential( self.label_emb = nn.Sequential(
nn.Sequential( nn.Sequential(
linear(adm_in_channels, time_embed_dim), linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
nn.SiLU(), nn.SiLU(),
linear(time_embed_dim, time_embed_dim), linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
) )
) )
else: else:
@ -584,7 +585,7 @@ class UNetModel(nn.Module):
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [
TimestepEmbedSequential( TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1) conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
) )
] ]
) )