mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
Initialize more unet weights as the right dtype.
This commit is contained in:
parent
e21d9ad445
commit
7bf89ba923
@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
|
|||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
up=False,
|
up=False,
|
||||||
down=False,
|
down=False,
|
||||||
|
dtype=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
@ -221,7 +222,7 @@ class ResBlock(TimestepBlock):
|
|||||||
self.in_layers = nn.Sequential(
|
self.in_layers = nn.Sequential(
|
||||||
normalization(channels),
|
normalization(channels),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.updown = up or down
|
self.updown = up or down
|
||||||
@ -247,7 +248,7 @@ class ResBlock(TimestepBlock):
|
|||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Dropout(p=dropout),
|
nn.Dropout(p=dropout),
|
||||||
zero_module(
|
zero_module(
|
||||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
|
|||||||
self.skip_connection = nn.Identity()
|
self.skip_connection = nn.Identity()
|
||||||
elif use_conv:
|
elif use_conv:
|
||||||
self.skip_connection = conv_nd(
|
self.skip_connection = conv_nd(
|
||||||
dims, channels, self.out_channels, 3, padding=1
|
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
"""
|
"""
|
||||||
@ -558,9 +559,9 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * 4
|
||||||
self.time_embed = nn.Sequential(
|
self.time_embed = nn.Sequential(
|
||||||
linear(model_channels, time_embed_dim),
|
linear(model_channels, time_embed_dim, dtype=self.dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(time_embed_dim, time_embed_dim),
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
@ -573,9 +574,9 @@ class UNetModel(nn.Module):
|
|||||||
assert adm_in_channels is not None
|
assert adm_in_channels is not None
|
||||||
self.label_emb = nn.Sequential(
|
self.label_emb = nn.Sequential(
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
linear(adm_in_channels, time_embed_dim),
|
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
linear(time_embed_dim, time_embed_dim),
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -584,7 +585,7 @@ class UNetModel(nn.Module):
|
|||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user